batchwriteitem is alive

This commit is contained in:
Emil Lerch 2024-02-08 15:25:41 -08:00
parent 6b7ff24f69
commit 685bae479b
Signed by: lobo
GPG Key ID: A7B62D657EF764F8
5 changed files with 651 additions and 52 deletions

View File

@ -30,11 +30,18 @@ pub fn deinit(self: Self) void {
pub var data_dir: []const u8 = "";
pub var test_retain_db: bool = false;
var test_db: ?sqlite.Db = null;
var test_db: ?*sqlite.Db = null;
pub fn testDbDeinit() void {
test_retain_db = false;
if (test_db) |db| {
db.deinit();
test_db = null;
}
}
/// Gets the database for this account. If under test, a memory database is used
/// instead. Will initialize the database with appropriate metadata tables
pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !sqlite.Db {
pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db {
const builtin = @import("builtin");
if (builtin.is_test and test_retain_db)
if (test_db) |db| return db;
@ -47,7 +54,8 @@ pub fn dbForAccount(allocator: std.mem.Allocator, account_id: []const u8) !sqlit
defer allocator.free(db_file_name);
const mode = if (builtin.is_test) sqlite.Db.Mode.Memory else sqlite.Db.Mode{ .File = db_file_name };
const new = mode == .Memory or (std.fs.cwd().statFile(file_without_path) catch null == null);
var db = try sqlite.Db.init(.{
var db = try allocator.create(sqlite.Db);
db.* = try sqlite.Db.init(.{
.mode = mode,
.open_flags = .{
.write = true,

View File

@ -21,6 +21,63 @@ const AttributeValue = union(ddb.AttributeTypeName) {
binary_set: [][]const u8,
const Self = @This();
pub fn jsonParse(
allocator: std.mem.Allocator,
source: *std.json.Scanner,
options: std.json.ParseOptions,
) !Self {
if (.object_begin != try source.next()) return error.UnexpectedToken;
const token = try source.nextAlloc(allocator, options.allocate.?);
if (token != .string) return error.UnexpectedToken;
var rc: Self = undefined;
if (std.mem.eql(u8, token.string, "string"))
rc = Self{ .string = try std.json.innerParse([]const u8, allocator, source, options) };
if (std.mem.eql(u8, token.string, "number"))
rc = Self{ .number = try std.json.innerParse([]const u8, allocator, source, options) };
if (std.mem.eql(u8, token.string, "binary"))
rc = Self{ .binary = try std.json.innerParse([]const u8, allocator, source, options) };
if (std.mem.eql(u8, token.string, "boolean"))
rc = Self{ .boolean = try std.json.innerParse(bool, allocator, source, options) };
if (std.mem.eql(u8, token.string, "null"))
rc = Self{ .null = try std.json.innerParse(bool, allocator, source, options) };
if (std.mem.eql(u8, token.string, "string_set"))
rc = Self{ .string_set = try std.json.innerParse([][]const u8, allocator, source, options) };
if (std.mem.eql(u8, token.string, "number_set"))
rc = Self{ .number_set = try std.json.innerParse([][]const u8, allocator, source, options) };
if (std.mem.eql(u8, token.string, "binary_set"))
rc = Self{ .binary_set = try std.json.innerParse([][]const u8, allocator, source, options) };
if (std.mem.eql(u8, token.string, "list")) {
var json = try std.json.Value.jsonParse(allocator, source, options);
rc = Self{ .list = json.array };
}
if (std.mem.eql(u8, token.string, "map")) {
var json = try std.json.Value.jsonParse(allocator, source, options);
rc = Self{ .map = json.object };
}
if (.object_end != try source.next()) return error.UnexpectedToken;
return rc;
}
pub fn jsonStringify(self: Self, jws: anytype) !void {
try jws.beginObject();
try jws.objectField(@tagName(self));
switch (self) {
.string, .number, .binary => |s| try jws.write(s),
.boolean, .null => |b| try jws.write(b),
.string_set, .number_set, .binary_set => |s| try jws.write(s),
.list => |l| try jws.write(l.items),
.map => |inner| {
try jws.beginObject();
var it = inner.iterator();
while (it.next()) |entry| {
try jws.objectField(entry.key_ptr.*);
try jws.write(entry.value_ptr.*);
}
try jws.endObject();
},
}
return try jws.endObject();
}
pub fn validate(self: Self) !void {
switch (self) {
.string, .string_set, .boolean, .null, .map, .list => {},
@ -475,19 +532,114 @@ const Params = struct {
pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 {
const allocator = request.allocator;
const account_id = request.account_id;
_ = account_id;
var params = try Params.parseRequest(allocator, request, writer);
defer params.deinit();
try params.validate();
// 1. Get the list of encrypted table names using the account id root key
// 2. Get the matching table-scope encryption keys
// 3. For each table request:
// 1. Find the hash values of put and delete requests in the request
// 2. Encrypt the hash values
// 3. Delete any existing records with that hash value (for delete requests, we're done here)
// 4. If put request, put the new item in the table (with encrypted values, using table encryption)
var account_tables = try ddb.tablesForAccount(allocator, account_id);
defer account_tables.deinit();
// 2. For each table request:
for (params.request_items) |table_req| {
var request_table: ddb.Table = undefined;
var found = false;
for (account_tables.items) |tbl| {
if (std.mem.eql(u8, tbl.name, table_req.table_name)) {
request_table = tbl;
found = true;
}
}
if (!found) {
std.log.warn("Table name in request does not exist in account. Table name specified: {s}", .{table_req.table_name});
continue; // TODO: This API has the concept of returning the list of unprocessed stuff. We need to do that here
}
for (table_req.requests) |req| {
if (req.put_request) |p|
try process_request(allocator, account_tables.db, &request_table, .put, p);
if (req.delete_request) |d|
try process_request(allocator, account_tables.db, &request_table, .delete, d);
}
}
// TODO: Capacity limiting and metrics
return "hi";
if (params.return_consumed_capacity != .none or params.return_item_collection_metrics)
try returnException(
request,
.internal_server_error,
error.NotImplemented,
writer,
"Changes processed, but metrics/capacity are not yet implemented",
);
// {
// "UnprocessedItems": {
// "Forum": [
// {
// "PutRequest": {
// "Item": {
// "Name": {
// "S": "Amazon ElastiCache"
// },
// "Category": {
// "S": "Amazon Web Services"
// }
// }
// }
// }
// ]
// },
// "ConsumedCapacity": [
// {
// "TableName": "Forum",
// "CapacityUnits": 3
// }
// ]
// }
return "{}";
}
const RequestType = enum {
put,
delete,
};
fn process_request(
allocator: std.mem.Allocator,
db: anytype,
table: *ddb.Table,
req_type: RequestType,
req_attributes: []Attribute,
) !void {
_ = db;
// 1. Find the hash values of put and delete requests in the request
const hash_key_attribute_name = table.info.value.hash_key_attribute_name;
const range_key_attribute_name = table.info.value.range_key_attribute_name;
var hash_attribute: ?Attribute = null;
var range_attribute: ?Attribute = null;
for (req_attributes) |*att| {
if (std.mem.eql(u8, att.name, hash_key_attribute_name)) {
hash_attribute = att.*;
continue;
}
if (range_key_attribute_name) |r| {
if (std.mem.eql(u8, att.name, r))
range_attribute = att.*;
}
}
if (hash_attribute == null) return error.HashAttributeNotFound;
if (range_attribute == null and range_key_attribute_name != null) return error.RangeAttributeNotFound;
const hash_value = try std.json.stringifyAlloc(allocator, hash_attribute.?.value, .{});
defer allocator.free(hash_value);
const range_value = if (range_attribute) |r|
try std.json.stringifyAlloc(allocator, r.value, .{})
else
null;
defer if (range_value) |r| allocator.free(r);
if (req_type == .delete) {
try table.deleteItem(hash_value, range_value);
} else {
const attributes_as_string = try std.json.stringifyAlloc(allocator, req_attributes, .{});
defer allocator.free(attributes_as_string);
try table.putItem(hash_value, range_value, attributes_as_string);
}
}
test "basic request parsing failure" {
@ -661,3 +813,248 @@ test "all types request parsing" {
try std.testing.expectEqualStrings("this text is base64-encoded", buf);
try std.testing.expect(put_and_or_delete.delete_request == null);
}
test "write item" {
Account.test_retain_db = true;
defer Account.testDbDeinit();
const allocator = std.testing.allocator;
const account_id = "1234";
var db = try Account.dbForAccount(allocator, account_id);
defer allocator.destroy(db);
defer Account.testDbDeinit();
const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed
defer account.deinit();
var hash = ddb.AttributeDefinition{ .name = "Artist", .type = .S };
var range = ddb.AttributeDefinition{ .name = "SongTitle", .type = .S };
var definitions = @constCast(&[_]*ddb.AttributeDefinition{
&hash,
&range,
});
var table_info: ddb.TableInfo = .{
.table_key = undefined,
.attribute_definitions = definitions[0..],
.hash_key_attribute_name = "Artist",
.range_key_attribute_name = "SongTitle",
};
encryption.randomEncodedKey(&table_info.table_key);
try ddb.createDdbTable(
allocator,
db,
account,
"MusicCollection",
table_info,
5,
5,
false,
);
var request = AuthenticatedRequest{
.output_format = .text,
.event_data =
\\ {
\\ "RequestItems": {
\\ "MusicCollection": [
\\ {
\\ "PutRequest": {
\\ "Item": {
\\ "Artist": {
\\ "S": "Mettalica"
\\ },
\\ "SongTitle": {
\\ "S": "Master of Puppets"
\\ },
\\ "Binary": {
\\ "B": "dGhpcyB0ZXh0IGlzIGJhc2U2NC1lbmNvZGVk"
\\ },
\\ "Boolean": {
\\ "BOOL": true
\\ },
\\ "Null": {
\\ "NULL": true
\\ },
\\ "List": {
\\ "L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}]
\\ },
\\ "Map": {
\\ "M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}}
\\ },
\\ "Number Set": {
\\ "NS": ["42.2", "-19", "7.5", "3.14"]
\\ },
\\ "Binary Set": {
\\ "BS": ["U3Vubnk=", "UmFpbnk=", "U25vd3k="]
\\ },
\\ "String Set": {
\\ "SS": ["Giraffe", "Hippo" ,"Zebra"]
\\ }
\\ }
\\ }
\\ }
\\ ]
\\ }
\\ }
,
.headers = undefined,
.status = .ok,
.reason = "",
.account_id = "1234",
.allocator = allocator,
};
var al = std.ArrayList(u8).init(allocator);
defer al.deinit();
var writer = al.writer();
_ = try handler(&request, writer);
}
test "round trip attributes" {
const allocator = std.testing.allocator;
var json_stuff = try std.json.parseFromSlice(std.json.Value, allocator,
\\ {
\\ "M": {"Name": {"S": "Joe"}, "Age": {"N": "35"}},
\\ "L": [ {"S": "Cookies"} , {"S": "Coffee"}, {"N": "3.14159"}]
\\ }
, .{});
defer json_stuff.deinit();
const map = json_stuff.value.object.get("M").?.object;
const list = json_stuff.value.object.get("L").?.array;
const attributes = &[_]Attribute{
.{
.name = "foo",
.value = .{ .string = "bar" },
},
.{
.name = "foo",
.value = .{ .number = "42" },
},
.{
.name = "foo",
.value = .{ .binary = "YmFy" }, // "bar"
},
.{
.name = "foo",
.value = .{ .boolean = true },
},
.{
.name = "foo",
.value = .{ .null = false },
},
.{
.name = "foo",
.value = .{ .string_set = @constCast(&[_][]const u8{ "foo", "bar" }) },
},
.{
.name = "foo",
.value = .{ .number_set = @constCast(&[_][]const u8{ "41", "42" }) },
},
.{
.name = "foo",
.value = .{ .binary_set = @constCast(&[_][]const u8{ "Zm9v", "YmFy" }) }, // foo, bar
},
.{
.name = "foo",
.value = .{ .map = map },
},
.{
.name = "foo",
.value = .{ .list = list },
},
};
const attributes_as_string = try std.json.stringifyAlloc(
allocator,
attributes,
.{ .whitespace = .indent_2 },
);
defer allocator.free(attributes_as_string);
try std.testing.expectEqualStrings(
\\[
\\ {
\\ "name": "foo",
\\ "value": {
\\ "string": "bar"
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "number": "42"
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "binary": "YmFy"
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "boolean": true
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "null": false
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "string_set": [
\\ "foo",
\\ "bar"
\\ ]
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "number_set": [
\\ "41",
\\ "42"
\\ ]
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "binary_set": [
\\ "Zm9v",
\\ "YmFy"
\\ ]
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "map": {
\\ "Name": {
\\ "S": "Joe"
\\ },
\\ "Age": {
\\ "N": "35"
\\ }
\\ }
\\ }
\\ },
\\ {
\\ "name": "foo",
\\ "value": {
\\ "list": [
\\ {
\\ "S": "Cookies"
\\ },
\\ {
\\ "S": "Coffee"
\\ },
\\ {
\\ "N": "3.14159"
\\ }
\\ ]
\\ }
\\ }
\\]
, attributes_as_string);
var round_tripped = try std.json.parseFromSlice([]Attribute, allocator, attributes_as_string, .{});
defer round_tripped.deinit();
}

View File

@ -62,12 +62,14 @@ pub fn handler(request: *AuthenticatedRequest, writer: anytype) ![]const u8 {
allocator.free(request_params.table_info.attribute_definitions);
}
var db = try Account.dbForAccount(allocator, account_id);
defer allocator.destroy(db);
defer db.deinit();
const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed
defer account.deinit();
try ddb.createDdbTable(
allocator,
&db,
db,
account,
request_params.table_name,
request_params.table_info,

View File

@ -5,6 +5,17 @@ const Account = @import("Account.zig");
const encryption = @import("encryption.zig");
const builtin = @import("builtin");
// We need our enryption to be able to store/retrieve and otherwise work like
// a database. So the use of a nonce here defeats these use cases
const nonce = &[_]u8{
0x55, 0x4a, 0x38, 0x16, 0x55, 0x55, 0x2d, 0x05,
0x32, 0x70, 0x3f, 0xa0, 0xde, 0x3d, 0x2c, 0xb8,
0x89, 0x40, 0x07, 0xc5, 0x57, 0x7d, 0xa0, 0xb8,
};
fn encryptAndEncode(allocator: std.mem.Allocator, key: [encryption.key_length]u8, plaintext: []const u8) ![]const u8 {
return try encryption.encryptAndEncodeWithNonce(allocator, key, nonce.*, plaintext);
}
/// Serialized into metadata table. This is an explicit enum with a twin
/// AttributeTypeName enum to make coding with these types easier. Use
/// Descriptor for storage or communication with the outside world, and
@ -60,75 +71,193 @@ pub const TableInfo = struct {
range_key_attribute_name: ?[]const u8,
};
pub const TableArray = struct {
pub const AccountTables = struct {
items: []Table,
db: *sqlite.Db,
allocator: std.mem.Allocator,
pub fn init(allocator: std.mem.Allocator, length: usize) !TableArray {
pub fn init(allocator: std.mem.Allocator, length: usize, db: *sqlite.Db) !AccountTables {
return .{
.allocator = allocator,
.items = try allocator.alloc(Table, length),
.db = db,
};
}
pub fn deinit(self: *TableArray) void {
pub fn deinit(self: *AccountTables) void {
for (self.items) |*item|
item.deinit();
self.allocator.free(self.items);
}
};
pub const Table = struct {
table_name: []const u8,
table_key: [encryption.key_length]u8,
name: []const u8,
key: [encryption.key_length]u8,
info: std.json.Parsed(TableInfo),
/// underlying data for json parsed version
info_str: []const u8,
db: *sqlite.Db,
allocator: std.mem.Allocator,
encrypted_name: []const u8,
pub fn deleteItem(self: *Table, hash_value: []const u8, range_value: ?[]const u8) !void {
const encrypted_hash = try encryptAndEncode(self.allocator, self.key, hash_value);
defer self.allocator.free(encrypted_hash);
const encrypted_range = if (range_value) |r|
try encryptAndEncode(self.allocator, self.key, r)
else
null;
defer if (encrypted_range != null) self.allocator.free(encrypted_range.?);
// TODO: hashKey and rangeKey are text, while hashvalue/rangevalue are blobx
// this is to accomodate non-string hash/range value by running a hash
// function over the data, probably base64 encoded. Do we want to do
// something like this?
const delete = try std.fmt.allocPrint(self.allocator,
\\DELETE FROM '{s}' WHERE
\\ hashKey = ?
\\ AND
\\ rangeKey = ?
, .{self.encrypted_name});
defer self.allocator.free(delete);
try self.db.execDynamic(delete, .{}, .{
encrypted_hash,
encrypted_range,
});
}
pub fn putItem(
self: *Table,
hash_value: []const u8,
range_value: ?[]const u8,
data: []const u8,
) !void {
var sp = try self.db.savepoint("putItem");
errdefer sp.rollback();
// TODO: savepoint this
try self.deleteItem(hash_value, range_value);
const encrypted_hash = try encryptAndEncode(self.allocator, self.key, hash_value);
defer self.allocator.free(encrypted_hash);
const encrypted_range = if (range_value) |r|
try encryptAndEncode(self.allocator, self.key, r)
else
null;
defer if (encrypted_range != null) self.allocator.free(encrypted_range.?);
const encrypted_data = try encryptAndEncode(self.allocator, self.key, data);
defer self.allocator.free(encrypted_data);
// TODO: hashKey and rangeKey are text, while hashvalue/rangevalue are blobx
// this is to accomodate non-string hash/range value by running a hash
// function over the data, probably base64 encoded. Do we want to do
// something like this?
const insert = try std.fmt.allocPrint(self.allocator,
\\INSERT INTO '{s}' (
\\ hashKey,
\\ rangeKey,
\\ hashValue,
\\ rangeValue,
\\ itemSize,
\\ ObjectJSON
\\ ) VALUES ( ?, ?, ?, ?, ?, ? )
// This syntax doesn't seem to work here? Used ?'s above
// \\ $encrypted_hash{{[]const u8}},
// \\ $encrypted_range{{[]const u8}},
// \\ $encrypted_hash{{[]const u8}},
// \\ $encrypted_range{{[]const u8}},
// \\ $len{{usize}},
// \\ $encrypted_data{{[]const u8}}
// \\ )
, .{self.encrypted_name});
defer self.allocator.free(insert);
var diags = sqlite.Diagnostics{};
// std.debug.print(
// \\====================
// \\Insert to table: {s}
// \\ hashKey: {s}
// \\ rangeKey: {?s}
// \\ hashValue: {s}
// \\ rangeValue: {?s}
// \\ itemSize: {d}
// \\ ObjectJSON: {s}
// \\====================
// , .{
// self.encrypted_name,
// encrypted_hash,
// encrypted_range,
// encrypted_hash,
// encrypted_range,
// encrypted_data.len,
// encrypted_data,
// });
self.db.execDynamic(insert, .{ .diags = &diags }, .{
encrypted_hash,
encrypted_range,
encrypted_hash,
encrypted_range,
encrypted_data.len,
encrypted_data,
}) catch |e| {
std.debug.print("Insert stmt: {s}\n", .{insert});
std.debug.print("SqlLite diags: {s}\n", .{diags});
return e;
};
sp.commit();
}
pub fn deinit(self: *Table) void {
std.crypto.utils.secureZero(u8, &self.table_key);
self.allocator.free(self.table_name);
std.crypto.utils.secureZero(u8, &self.key);
self.allocator.free(self.encrypted_name);
self.allocator.free(self.info_str);
self.info.deinit();
self.allocator.free(self.name);
}
};
// Gets all table names/keys for the account. Caller owns returned array
pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !TableArray {
/// Gets all table names/keys for the account. Caller owns returned array
/// The return value will also provide the opened database. As encryption keys
/// are stored in here, realistically, this will be the first function called
/// every time anything interacts with the database, so this function opens
/// the database for you
pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !AccountTables {
// TODO: This function should take a list of table names, which can then be used
// to filter the query below rather than just grabbing everything
var db = try Account.dbForAccount(allocator, account_id);
defer if (!builtin.is_test) db.deinit();
errdefer if (!builtin.is_test) db.deinit();
const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed
defer account.deinit();
const query =
\\SELECT TableName as table_name, TableInfo as table_info FROM dm
\\SELECT TableName as name, TableInfo as info FROM dm
;
var stmt = try db.prepare(query);
defer stmt.deinit();
const rows = try stmt.all(struct {
table_name: []const u8,
table_info: []const u8,
name: []const u8,
info: []const u8,
}, allocator, .{}, .{});
defer allocator.free(rows);
var rc = try TableArray.init(allocator, rows.len);
var rc = try AccountTables.init(allocator, rows.len, db);
errdefer rc.deinit();
// std.debug.print(" \n===\nRow count: {d}\n===\n", .{rows.len});
for (rows, 0..) |row, inx| {
defer allocator.free(row.table_name);
defer allocator.free(row.table_info);
defer allocator.free(row.name);
defer allocator.free(row.info);
const table_name = try encryption.decodeAndDecrypt(
allocator,
account.root_account_key.*,
row.table_name,
row.name,
);
errdefer allocator.free(table_name);
const table_info_str = try encryption.decodeAndDecrypt(
allocator,
account.root_account_key.*,
row.table_info,
row.info,
);
defer allocator.free(table_info_str);
// std.debug.print(" \n===TableInfo: {s}\n===\n", .{table_info_str});
const table_info = try std.json.parseFromSlice(TableInfo, allocator, table_info_str, .{});
defer table_info.deinit();
// errdefer allocator.free(table_info.table_key);
// defer {
// // we don't even really need to defer this...
@ -141,10 +270,14 @@ pub fn tablesForAccount(allocator: std.mem.Allocator, account_id: []const u8) !T
rc.items[inx] = .{
.allocator = allocator,
.table_name = table_name,
.table_key = undefined,
.name = table_name,
.encrypted_name = try allocator.dupe(u8, row.name),
.key = undefined,
.info = try std.json.parseFromSlice(TableInfo, allocator, table_info_str, .{}),
.info_str = table_info_str,
.db = db,
};
try encryption.decodeKey(&rc.items[inx].table_key, table_info.value.table_key);
try encryption.decodeKey(&rc.items[inx].key, rc.items[inx].info.value.table_key);
}
return rc;
}
@ -220,12 +353,12 @@ fn insertIntoDatabaseMetadata(
billing_mode_pay_per_request: bool,
) ![]const u8 {
// TODO: better to do all encryption when request params are parsed?
const encrypted_table_name = try encryption.encryptAndEncode(allocator, account.root_account_key.*, table_name);
const encrypted_table_name = try encryptAndEncode(allocator, account.root_account_key.*, table_name);
errdefer allocator.free(encrypted_table_name);
// We'll json serialize our table_info structure, encrypt, encode, and plow in
const table_info_string = try std.json.stringifyAlloc(allocator, table_info, .{ .whitespace = .indent_2 });
defer allocator.free(table_info_string);
const encrypted_table_info = try encryption.encryptAndEncode(allocator, account.root_account_key.*, table_info_string);
const encrypted_table_info = try encryptAndEncode(allocator, account.root_account_key.*, table_info_string);
defer allocator.free(encrypted_table_info);
try insertIntoDm(db, encrypted_table_name, encrypted_table_info, read_capacity_units, write_capacity_units, billing_mode_pay_per_request);
return encrypted_table_name;
@ -279,7 +412,7 @@ fn insertIntoDm(
});
}
fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !sqlite.Db {
fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !*sqlite.Db {
var db = try Account.dbForAccount(allocator, account_id);
const account = try Account.accountForId(allocator, account_id); // This will get us the encryption key needed
defer account.deinit();
@ -298,7 +431,7 @@ fn testCreateTable(allocator: std.mem.Allocator, account_id: []const u8) !sqlite
encryption.randomEncodedKey(&table_info.table_key);
try createDdbTable(
allocator,
&db,
db,
account,
"MusicCollection",
table_info,
@ -312,18 +445,38 @@ test "can create a table" {
const allocator = std.testing.allocator;
const account_id = "1234";
var db = try testCreateTable(allocator, account_id);
defer allocator.destroy(db);
defer db.deinit();
}
test "can list tables in an account" {
Account.test_retain_db = true;
defer Account.test_retain_db = false;
const allocator = std.testing.allocator;
const account_id = "1234";
var db = try testCreateTable(allocator, account_id);
defer db.deinit();
defer allocator.destroy(db);
defer Account.testDbDeinit();
var table_list = try tablesForAccount(allocator, account_id);
defer table_list.deinit();
try std.testing.expectEqual(@as(usize, 1), table_list.items.len);
try std.testing.expectEqualStrings("MusicCollection", table_list.items[0].table_name);
try std.testing.expectEqualStrings("MusicCollection", table_list.items[0].name);
// std.debug.print(" \n===\nKey: {s}\n===\n", .{std.fmt.fmtSliceHexLower(&table_list.items[0].table_key)});
}
test "can put an item in a table in an account" {
Account.test_retain_db = true;
const allocator = std.testing.allocator;
const account_id = "1234";
var db = try testCreateTable(allocator, account_id);
defer allocator.destroy(db);
defer Account.testDbDeinit();
var table_list = try tablesForAccount(allocator, account_id);
defer table_list.deinit();
try std.testing.expectEqualStrings("MusicCollection", table_list.items[0].name);
var table = table_list.items[0];
try table.putItem("Foo Fighters", "Everlong", "whatevs");
// This should succeed, because putItem is an upsert mechanism
try table.putItem("Foo Fighters", "Everlong", "whatevs");
// TODO: this test should do getItem to verify data
// std.debug.print(" \n===\nKey: {s}\n===\n", .{std.fmt.fmtSliceHexLower(&table_list.items[0].table_key)});
}

View File

@ -5,6 +5,7 @@ pub const salt_length = 256 / 8; // https://crypto.stackexchange.com/a/56132
pub const encoded_salt_length = std.base64.standard.Encoder.calcSize(salt_length);
pub const key_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.key_length;
pub const encoded_key_length = std.base64.standard.Encoder.calcSize(key_length);
pub const nonce_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length;
/// Generates a random salt of appropriate length
pub fn randomSalt(salt: *[salt_length]u8) void {
@ -56,19 +57,25 @@ pub fn deriveKeyFromEncodedSalt(derived_key: *[key_length]u8, password: []const
return derived_key.*;
}
/// Encrypts data. Use deriveKey function to get a key from password/salt
/// Encrypts data. Use deriveKey function to get a key from password/salt.
/// Uses a random nonce. To supply a nonce instead, use encryptWithNonce
/// Caller owns memory
pub fn encrypt(allocator: std.mem.Allocator, key: [key_length]u8, plaintext: []const u8) ![]const u8 {
var nonce: [std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length]u8 = undefined;
std.crypto.random.bytes(&nonce); // add nonce to beginning of our ciphertext
return try encryptWithNonce(allocator, key, nonce, plaintext);
}
pub fn encryptWithNonce(allocator: std.mem.Allocator, key: [key_length]u8, nonce: [nonce_length]u8, plaintext: []const u8) ![]const u8 {
var ciphertext = try allocator.alloc(
u8,
std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length + plaintext.len,
nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length + plaintext.len,
);
errdefer allocator.free(ciphertext);
// Create the nonce
const nonce_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length;
std.crypto.random.bytes(ciphertext[0..nonce_length]); // add nonce to beginning of our ciphertext
const nonce = ciphertext[0..nonce_length];
@memcpy(ciphertext[0..nonce_length], nonce[0..]); // add nonce to beginning of our ciphertext
const nonce_copy = ciphertext[0..nonce_length];
const tag = ciphertext[nonce_length .. nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length];
const c = ciphertext[nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length ..];
@ -78,7 +85,7 @@ pub fn encrypt(allocator: std.mem.Allocator, key: [key_length]u8, plaintext: []c
tag,
plaintext,
"ad",
nonce.*,
nonce_copy.*,
key,
);
return ciphertext;
@ -95,14 +102,24 @@ pub fn encryptAndEncode(allocator: std.mem.Allocator, key: [key_length]u8, plain
return Encoder.encode(encoded_ciphertext, ciphertext);
}
/// Encrypts data. Use deriveKey function to get a key from password/salt
/// Caller owns memory
pub fn encryptAndEncodeWithNonce(allocator: std.mem.Allocator, key: [key_length]u8, nonce: [nonce_length]u8, plaintext: []const u8) ![]const u8 {
const ciphertext = try encryptWithNonce(allocator, key, nonce, plaintext);
defer allocator.free(ciphertext);
const Encoder = std.base64.standard.Encoder;
var encoded_ciphertext = try allocator.alloc(u8, Encoder.calcSize(ciphertext.len));
errdefer allocator.free(encoded_ciphertext);
return Encoder.encode(encoded_ciphertext, ciphertext);
}
/// Decrypts data. Use deriveKey function to get a key from password/salt
pub fn decrypt(allocator: std.mem.Allocator, key: [key_length]u8, ciphertext: []const u8) ![]const u8 {
var plaintext = try allocator.alloc(
u8,
ciphertext.len - std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length - std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length,
ciphertext.len - nonce_length - std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length,
);
errdefer allocator.free(plaintext);
const nonce_length = std.crypto.aead.salsa_poly.XSalsa20Poly1305.nonce_length;
const nonce = ciphertext[0..nonce_length].*;
const tag = ciphertext[nonce_length .. nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length].*;
const c = ciphertext[nonce_length + std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length ..];
@ -164,6 +181,28 @@ test "can encrypt and decrypt data with simpler api but without KDF" {
try std.testing.expectEqualStrings(plaintext, decrypted_text[0..]);
}
test "can encrypt twice with same result" {
const allocator = std.testing.allocator;
const plaintext = "Hello, Zig!";
var key: [key_length]u8 = undefined;
var nonce: [nonce_length]u8 = undefined;
var encoded_key: [encoded_key_length]u8 = undefined;
randomEncodedKey(encoded_key[0..]);
std.crypto.random.bytes(&nonce);
// std.testing.log_level = .debug;
std.log.debug("Encoded key: {s}", .{encoded_key});
try decodeKey(&key, encoded_key);
const ciphertext = try encryptWithNonce(allocator, key, nonce, plaintext);
defer allocator.free(ciphertext);
std.log.debug("Ciphertext: {s}\n", .{std.fmt.fmtSliceHexLower(ciphertext)});
const ciphertext2 = try encryptWithNonce(allocator, key, nonce, plaintext);
defer allocator.free(ciphertext2);
std.log.debug("Ciphertext: {s}\n", .{std.fmt.fmtSliceHexLower(ciphertext2)});
try std.testing.expectEqualSlices(u8, ciphertext, ciphertext2);
}
// test "can encrypt and decrypt data" {
// var tag: [std.crypto.aead.salsa_poly.XSalsa20Poly1305.tag_length]u8 = undefined;
// const password = "mySecurePassword";