From 703e77c19d82e4d66c0cd68d0b77c0c31956b616 Mon Sep 17 00:00:00 2001 From: Emil Lerch Date: Thu, 18 Dec 2025 13:49:32 -0800 Subject: [PATCH] remove error-based control flow/handle ipv6 --- src/http/RateLimiter.zig | 35 ++++++++++++++++++----------------- src/http/middleware.zig | 19 ++++++++----------- 2 files changed, 26 insertions(+), 28 deletions(-) diff --git a/src/http/RateLimiter.zig b/src/http/RateLimiter.zig index 0fb75cf..4b4ab73 100644 --- a/src/http/RateLimiter.zig +++ b/src/http/RateLimiter.zig @@ -48,15 +48,18 @@ pub fn init(allocator: std.mem.Allocator, config: Config) !RateLimiter { }; } -pub fn check(self: *RateLimiter, ip: []const u8) !void { +/// Checks if a request from the given IP should be accepted. +/// Note: Calling this function consumes a token from the bucket, even if it returns false. +/// Returns true if the request should be accepted, false if rate limited. +pub fn shouldAcceptRequest(self: *RateLimiter, ip: []const u8) bool { self.mutex.lock(); defer self.mutex.unlock(); const now = std.time.milliTimestamp(); - const result = try self.buckets.getOrPut(ip); + const result = self.buckets.getOrPut(ip) catch return false; if (!result.found_existing) { - const ip_copy = try self.allocator.dupe(u8, ip); + const ip_copy = self.allocator.dupe(u8, ip) catch return false; result.key_ptr.* = ip_copy; result.value_ptr.* = TokenBucket{ .tokens = @floatFromInt(self.config.capacity), @@ -68,9 +71,7 @@ pub fn check(self: *RateLimiter, ip: []const u8) !void { var bucket = result.value_ptr; bucket.refill(now, self.config.refill_rate, self.config.refill_interval_ms); - if (!bucket.consume(1.0)) { - return error.RateLimitExceeded; - } + return bucket.consume(1.0); } pub fn deinit(self: *RateLimiter) void { @@ -91,7 +92,7 @@ test "rate limiter allows requests within capacity" { var i: usize = 0; while (i < 10) : (i += 1) { - try limiter.check("1.2.3.4"); + try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4")); } } @@ -105,10 +106,10 @@ test "rate limiter blocks after capacity exhausted" { var i: usize = 0; while (i < 5) : (i += 1) { - try limiter.check("1.2.3.4"); + try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4")); } - try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4")); + try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4")); } test "rate limiter refills tokens over time" { @@ -121,14 +122,14 @@ test "rate limiter refills tokens over time" { var i: usize = 0; while (i < 10) : (i += 1) { - try limiter.check("1.2.3.4"); + try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4")); } - try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4")); + try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4")); std.Thread.sleep(250 * std.time.ns_per_ms); - try limiter.check("1.2.3.4"); + try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4")); } test "rate limiter tracks different IPs separately" { @@ -139,11 +140,11 @@ test "rate limiter tracks different IPs separately" { }); defer limiter.deinit(); - try limiter.check("1.2.3.4"); - try limiter.check("1.2.3.4"); + try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4")); + try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4")); - try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4")); + try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4")); - try limiter.check("5.6.7.8"); - try limiter.check("5.6.7.8"); + try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8")); + try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8")); } diff --git a/src/http/middleware.zig b/src/http/middleware.zig index 8c404f7..203976e 100644 --- a/src/http/middleware.zig +++ b/src/http/middleware.zig @@ -3,18 +3,15 @@ const httpz = @import("httpz"); const RateLimiter = @import("RateLimiter.zig"); pub fn rateLimitMiddleware(limiter: *RateLimiter, req: *httpz.Request, res: *httpz.Response) !void { - const ip = req.address.in.sa.addr; - var ip_buf: [16]u8 = undefined; - const ip_str = try std.fmt.bufPrint(&ip_buf, "{d}.{d}.{d}.{d}", .{ - ip & 0xFF, - (ip >> 8) & 0xFF, - (ip >> 16) & 0xFF, - (ip >> 24) & 0xFF, - }); + var ip_buf: [45]u8 = undefined; // https://stackoverflow.com/a/166157 + const ip_str = try std.fmt.bufPrint( + &ip_buf, + "{f}", + .{req.address}, + ); - limiter.check(ip_str) catch { + if (!limiter.shouldAcceptRequest(ip_str)) { res.status = 429; res.body = "Too Many Requests"; - return; - }; + } }