remove error-based control flow/handle ipv6

This commit is contained in:
Emil Lerch 2025-12-18 13:49:32 -08:00
parent 8d9204ab8d
commit 703e77c19d
Signed by: lobo
GPG key ID: A7B62D657EF764F8
2 changed files with 26 additions and 28 deletions

View file

@ -48,15 +48,18 @@ pub fn init(allocator: std.mem.Allocator, config: Config) !RateLimiter {
};
}
pub fn check(self: *RateLimiter, ip: []const u8) !void {
/// Checks if a request from the given IP should be accepted.
/// Note: Calling this function consumes a token from the bucket, even if it returns false.
/// Returns true if the request should be accepted, false if rate limited.
pub fn shouldAcceptRequest(self: *RateLimiter, ip: []const u8) bool {
self.mutex.lock();
defer self.mutex.unlock();
const now = std.time.milliTimestamp();
const result = try self.buckets.getOrPut(ip);
const result = self.buckets.getOrPut(ip) catch return false;
if (!result.found_existing) {
const ip_copy = try self.allocator.dupe(u8, ip);
const ip_copy = self.allocator.dupe(u8, ip) catch return false;
result.key_ptr.* = ip_copy;
result.value_ptr.* = TokenBucket{
.tokens = @floatFromInt(self.config.capacity),
@ -68,9 +71,7 @@ pub fn check(self: *RateLimiter, ip: []const u8) !void {
var bucket = result.value_ptr;
bucket.refill(now, self.config.refill_rate, self.config.refill_interval_ms);
if (!bucket.consume(1.0)) {
return error.RateLimitExceeded;
}
return bucket.consume(1.0);
}
pub fn deinit(self: *RateLimiter) void {
@ -91,7 +92,7 @@ test "rate limiter allows requests within capacity" {
var i: usize = 0;
while (i < 10) : (i += 1) {
try limiter.check("1.2.3.4");
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
}
}
@ -105,10 +106,10 @@ test "rate limiter blocks after capacity exhausted" {
var i: usize = 0;
while (i < 5) : (i += 1) {
try limiter.check("1.2.3.4");
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
}
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
}
test "rate limiter refills tokens over time" {
@ -121,14 +122,14 @@ test "rate limiter refills tokens over time" {
var i: usize = 0;
while (i < 10) : (i += 1) {
try limiter.check("1.2.3.4");
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
}
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
std.Thread.sleep(250 * std.time.ns_per_ms);
try limiter.check("1.2.3.4");
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
}
test "rate limiter tracks different IPs separately" {
@ -139,11 +140,11 @@ test "rate limiter tracks different IPs separately" {
});
defer limiter.deinit();
try limiter.check("1.2.3.4");
try limiter.check("1.2.3.4");
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
try limiter.check("5.6.7.8");
try limiter.check("5.6.7.8");
try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
}

View file

@ -3,18 +3,15 @@ const httpz = @import("httpz");
const RateLimiter = @import("RateLimiter.zig");
pub fn rateLimitMiddleware(limiter: *RateLimiter, req: *httpz.Request, res: *httpz.Response) !void {
const ip = req.address.in.sa.addr;
var ip_buf: [16]u8 = undefined;
const ip_str = try std.fmt.bufPrint(&ip_buf, "{d}.{d}.{d}.{d}", .{
ip & 0xFF,
(ip >> 8) & 0xFF,
(ip >> 16) & 0xFF,
(ip >> 24) & 0xFF,
});
var ip_buf: [45]u8 = undefined; // https://stackoverflow.com/a/166157
const ip_str = try std.fmt.bufPrint(
&ip_buf,
"{f}",
.{req.address},
);
limiter.check(ip_str) catch {
if (!limiter.shouldAcceptRequest(ip_str)) {
res.status = 429;
res.body = "Too Many Requests";
return;
};
}
}