remove error-based control flow/handle ipv6

This commit is contained in:
Emil Lerch 2025-12-18 13:49:32 -08:00
parent 8d9204ab8d
commit 703e77c19d
Signed by: lobo
GPG key ID: A7B62D657EF764F8
2 changed files with 26 additions and 28 deletions

View file

@ -48,15 +48,18 @@ pub fn init(allocator: std.mem.Allocator, config: Config) !RateLimiter {
}; };
} }
pub fn check(self: *RateLimiter, ip: []const u8) !void { /// Checks if a request from the given IP should be accepted.
/// Note: Calling this function consumes a token from the bucket, even if it returns false.
/// Returns true if the request should be accepted, false if rate limited.
pub fn shouldAcceptRequest(self: *RateLimiter, ip: []const u8) bool {
self.mutex.lock(); self.mutex.lock();
defer self.mutex.unlock(); defer self.mutex.unlock();
const now = std.time.milliTimestamp(); const now = std.time.milliTimestamp();
const result = try self.buckets.getOrPut(ip); const result = self.buckets.getOrPut(ip) catch return false;
if (!result.found_existing) { if (!result.found_existing) {
const ip_copy = try self.allocator.dupe(u8, ip); const ip_copy = self.allocator.dupe(u8, ip) catch return false;
result.key_ptr.* = ip_copy; result.key_ptr.* = ip_copy;
result.value_ptr.* = TokenBucket{ result.value_ptr.* = TokenBucket{
.tokens = @floatFromInt(self.config.capacity), .tokens = @floatFromInt(self.config.capacity),
@ -68,9 +71,7 @@ pub fn check(self: *RateLimiter, ip: []const u8) !void {
var bucket = result.value_ptr; var bucket = result.value_ptr;
bucket.refill(now, self.config.refill_rate, self.config.refill_interval_ms); bucket.refill(now, self.config.refill_rate, self.config.refill_interval_ms);
if (!bucket.consume(1.0)) { return bucket.consume(1.0);
return error.RateLimitExceeded;
}
} }
pub fn deinit(self: *RateLimiter) void { pub fn deinit(self: *RateLimiter) void {
@ -91,7 +92,7 @@ test "rate limiter allows requests within capacity" {
var i: usize = 0; var i: usize = 0;
while (i < 10) : (i += 1) { while (i < 10) : (i += 1) {
try limiter.check("1.2.3.4"); try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
} }
} }
@ -105,10 +106,10 @@ test "rate limiter blocks after capacity exhausted" {
var i: usize = 0; var i: usize = 0;
while (i < 5) : (i += 1) { while (i < 5) : (i += 1) {
try limiter.check("1.2.3.4"); try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
} }
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4")); try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
} }
test "rate limiter refills tokens over time" { test "rate limiter refills tokens over time" {
@ -121,14 +122,14 @@ test "rate limiter refills tokens over time" {
var i: usize = 0; var i: usize = 0;
while (i < 10) : (i += 1) { while (i < 10) : (i += 1) {
try limiter.check("1.2.3.4"); try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
} }
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4")); try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
std.Thread.sleep(250 * std.time.ns_per_ms); std.Thread.sleep(250 * std.time.ns_per_ms);
try limiter.check("1.2.3.4"); try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
} }
test "rate limiter tracks different IPs separately" { test "rate limiter tracks different IPs separately" {
@ -139,11 +140,11 @@ test "rate limiter tracks different IPs separately" {
}); });
defer limiter.deinit(); defer limiter.deinit();
try limiter.check("1.2.3.4"); try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
try limiter.check("1.2.3.4"); try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4")); try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
try limiter.check("5.6.7.8"); try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
try limiter.check("5.6.7.8"); try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
} }

View file

@ -3,18 +3,15 @@ const httpz = @import("httpz");
const RateLimiter = @import("RateLimiter.zig"); const RateLimiter = @import("RateLimiter.zig");
pub fn rateLimitMiddleware(limiter: *RateLimiter, req: *httpz.Request, res: *httpz.Response) !void { pub fn rateLimitMiddleware(limiter: *RateLimiter, req: *httpz.Request, res: *httpz.Response) !void {
const ip = req.address.in.sa.addr; var ip_buf: [45]u8 = undefined; // https://stackoverflow.com/a/166157
var ip_buf: [16]u8 = undefined; const ip_str = try std.fmt.bufPrint(
const ip_str = try std.fmt.bufPrint(&ip_buf, "{d}.{d}.{d}.{d}", .{ &ip_buf,
ip & 0xFF, "{f}",
(ip >> 8) & 0xFF, .{req.address},
(ip >> 16) & 0xFF, );
(ip >> 24) & 0xFF,
});
limiter.check(ip_str) catch { if (!limiter.shouldAcceptRequest(ip_str)) {
res.status = 429; res.status = 429;
res.body = "Too Many Requests"; res.body = "Too Many Requests";
return; }
};
} }