remove error-based control flow/handle ipv6
This commit is contained in:
parent
8d9204ab8d
commit
703e77c19d
2 changed files with 26 additions and 28 deletions
|
|
@ -48,15 +48,18 @@ pub fn init(allocator: std.mem.Allocator, config: Config) !RateLimiter {
|
|||
};
|
||||
}
|
||||
|
||||
pub fn check(self: *RateLimiter, ip: []const u8) !void {
|
||||
/// Checks if a request from the given IP should be accepted.
|
||||
/// Note: Calling this function consumes a token from the bucket, even if it returns false.
|
||||
/// Returns true if the request should be accepted, false if rate limited.
|
||||
pub fn shouldAcceptRequest(self: *RateLimiter, ip: []const u8) bool {
|
||||
self.mutex.lock();
|
||||
defer self.mutex.unlock();
|
||||
|
||||
const now = std.time.milliTimestamp();
|
||||
|
||||
const result = try self.buckets.getOrPut(ip);
|
||||
const result = self.buckets.getOrPut(ip) catch return false;
|
||||
if (!result.found_existing) {
|
||||
const ip_copy = try self.allocator.dupe(u8, ip);
|
||||
const ip_copy = self.allocator.dupe(u8, ip) catch return false;
|
||||
result.key_ptr.* = ip_copy;
|
||||
result.value_ptr.* = TokenBucket{
|
||||
.tokens = @floatFromInt(self.config.capacity),
|
||||
|
|
@ -68,9 +71,7 @@ pub fn check(self: *RateLimiter, ip: []const u8) !void {
|
|||
var bucket = result.value_ptr;
|
||||
bucket.refill(now, self.config.refill_rate, self.config.refill_interval_ms);
|
||||
|
||||
if (!bucket.consume(1.0)) {
|
||||
return error.RateLimitExceeded;
|
||||
}
|
||||
return bucket.consume(1.0);
|
||||
}
|
||||
|
||||
pub fn deinit(self: *RateLimiter) void {
|
||||
|
|
@ -91,7 +92,7 @@ test "rate limiter allows requests within capacity" {
|
|||
|
||||
var i: usize = 0;
|
||||
while (i < 10) : (i += 1) {
|
||||
try limiter.check("1.2.3.4");
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -105,10 +106,10 @@ test "rate limiter blocks after capacity exhausted" {
|
|||
|
||||
var i: usize = 0;
|
||||
while (i < 5) : (i += 1) {
|
||||
try limiter.check("1.2.3.4");
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
}
|
||||
|
||||
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
|
||||
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
}
|
||||
|
||||
test "rate limiter refills tokens over time" {
|
||||
|
|
@ -121,14 +122,14 @@ test "rate limiter refills tokens over time" {
|
|||
|
||||
var i: usize = 0;
|
||||
while (i < 10) : (i += 1) {
|
||||
try limiter.check("1.2.3.4");
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
}
|
||||
|
||||
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
|
||||
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
|
||||
std.Thread.sleep(250 * std.time.ns_per_ms);
|
||||
|
||||
try limiter.check("1.2.3.4");
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
}
|
||||
|
||||
test "rate limiter tracks different IPs separately" {
|
||||
|
|
@ -139,11 +140,11 @@ test "rate limiter tracks different IPs separately" {
|
|||
});
|
||||
defer limiter.deinit();
|
||||
|
||||
try limiter.check("1.2.3.4");
|
||||
try limiter.check("1.2.3.4");
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
|
||||
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
|
||||
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
|
||||
|
||||
try limiter.check("5.6.7.8");
|
||||
try limiter.check("5.6.7.8");
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
|
||||
try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,18 +3,15 @@ const httpz = @import("httpz");
|
|||
const RateLimiter = @import("RateLimiter.zig");
|
||||
|
||||
pub fn rateLimitMiddleware(limiter: *RateLimiter, req: *httpz.Request, res: *httpz.Response) !void {
|
||||
const ip = req.address.in.sa.addr;
|
||||
var ip_buf: [16]u8 = undefined;
|
||||
const ip_str = try std.fmt.bufPrint(&ip_buf, "{d}.{d}.{d}.{d}", .{
|
||||
ip & 0xFF,
|
||||
(ip >> 8) & 0xFF,
|
||||
(ip >> 16) & 0xFF,
|
||||
(ip >> 24) & 0xFF,
|
||||
});
|
||||
var ip_buf: [45]u8 = undefined; // https://stackoverflow.com/a/166157
|
||||
const ip_str = try std.fmt.bufPrint(
|
||||
&ip_buf,
|
||||
"{f}",
|
||||
.{req.address},
|
||||
);
|
||||
|
||||
limiter.check(ip_str) catch {
|
||||
if (!limiter.shouldAcceptRequest(ip_str)) {
|
||||
res.status = 429;
|
||||
res.body = "Too Many Requests";
|
||||
return;
|
||||
};
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Reference in a new issue