remove error-based control flow/handle ipv6
This commit is contained in:
parent
8d9204ab8d
commit
703e77c19d
2 changed files with 26 additions and 28 deletions
|
|
@ -48,15 +48,18 @@ pub fn init(allocator: std.mem.Allocator, config: Config) !RateLimiter {
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn check(self: *RateLimiter, ip: []const u8) !void {
|
/// Checks if a request from the given IP should be accepted.
|
||||||
|
/// Note: Calling this function consumes a token from the bucket, even if it returns false.
|
||||||
|
/// Returns true if the request should be accepted, false if rate limited.
|
||||||
|
pub fn shouldAcceptRequest(self: *RateLimiter, ip: []const u8) bool {
|
||||||
self.mutex.lock();
|
self.mutex.lock();
|
||||||
defer self.mutex.unlock();
|
defer self.mutex.unlock();
|
||||||
|
|
||||||
const now = std.time.milliTimestamp();
|
const now = std.time.milliTimestamp();
|
||||||
|
|
||||||
const result = try self.buckets.getOrPut(ip);
|
const result = self.buckets.getOrPut(ip) catch return false;
|
||||||
if (!result.found_existing) {
|
if (!result.found_existing) {
|
||||||
const ip_copy = try self.allocator.dupe(u8, ip);
|
const ip_copy = self.allocator.dupe(u8, ip) catch return false;
|
||||||
result.key_ptr.* = ip_copy;
|
result.key_ptr.* = ip_copy;
|
||||||
result.value_ptr.* = TokenBucket{
|
result.value_ptr.* = TokenBucket{
|
||||||
.tokens = @floatFromInt(self.config.capacity),
|
.tokens = @floatFromInt(self.config.capacity),
|
||||||
|
|
@ -68,9 +71,7 @@ pub fn check(self: *RateLimiter, ip: []const u8) !void {
|
||||||
var bucket = result.value_ptr;
|
var bucket = result.value_ptr;
|
||||||
bucket.refill(now, self.config.refill_rate, self.config.refill_interval_ms);
|
bucket.refill(now, self.config.refill_rate, self.config.refill_interval_ms);
|
||||||
|
|
||||||
if (!bucket.consume(1.0)) {
|
return bucket.consume(1.0);
|
||||||
return error.RateLimitExceeded;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn deinit(self: *RateLimiter) void {
|
pub fn deinit(self: *RateLimiter) void {
|
||||||
|
|
@ -91,7 +92,7 @@ test "rate limiter allows requests within capacity" {
|
||||||
|
|
||||||
var i: usize = 0;
|
var i: usize = 0;
|
||||||
while (i < 10) : (i += 1) {
|
while (i < 10) : (i += 1) {
|
||||||
try limiter.check("1.2.3.4");
|
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -105,10 +106,10 @@ test "rate limiter blocks after capacity exhausted" {
|
||||||
|
|
||||||
var i: usize = 0;
|
var i: usize = 0;
|
||||||
while (i < 5) : (i += 1) {
|
while (i < 5) : (i += 1) {
|
||||||
try limiter.check("1.2.3.4");
|
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
|
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "rate limiter refills tokens over time" {
|
test "rate limiter refills tokens over time" {
|
||||||
|
|
@ -121,14 +122,14 @@ test "rate limiter refills tokens over time" {
|
||||||
|
|
||||||
var i: usize = 0;
|
var i: usize = 0;
|
||||||
while (i < 10) : (i += 1) {
|
while (i < 10) : (i += 1) {
|
||||||
try limiter.check("1.2.3.4");
|
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
}
|
}
|
||||||
|
|
||||||
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
|
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
|
|
||||||
std.Thread.sleep(250 * std.time.ns_per_ms);
|
std.Thread.sleep(250 * std.time.ns_per_ms);
|
||||||
|
|
||||||
try limiter.check("1.2.3.4");
|
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
}
|
}
|
||||||
|
|
||||||
test "rate limiter tracks different IPs separately" {
|
test "rate limiter tracks different IPs separately" {
|
||||||
|
|
@ -139,11 +140,11 @@ test "rate limiter tracks different IPs separately" {
|
||||||
});
|
});
|
||||||
defer limiter.deinit();
|
defer limiter.deinit();
|
||||||
|
|
||||||
try limiter.check("1.2.3.4");
|
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
try limiter.check("1.2.3.4");
|
try std.testing.expect(limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
|
|
||||||
try std.testing.expectError(error.RateLimitExceeded, limiter.check("1.2.3.4"));
|
try std.testing.expect(!limiter.shouldAcceptRequest("1.2.3.4"));
|
||||||
|
|
||||||
try limiter.check("5.6.7.8");
|
try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
|
||||||
try limiter.check("5.6.7.8");
|
try std.testing.expect(limiter.shouldAcceptRequest("5.6.7.8"));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,18 +3,15 @@ const httpz = @import("httpz");
|
||||||
const RateLimiter = @import("RateLimiter.zig");
|
const RateLimiter = @import("RateLimiter.zig");
|
||||||
|
|
||||||
pub fn rateLimitMiddleware(limiter: *RateLimiter, req: *httpz.Request, res: *httpz.Response) !void {
|
pub fn rateLimitMiddleware(limiter: *RateLimiter, req: *httpz.Request, res: *httpz.Response) !void {
|
||||||
const ip = req.address.in.sa.addr;
|
var ip_buf: [45]u8 = undefined; // https://stackoverflow.com/a/166157
|
||||||
var ip_buf: [16]u8 = undefined;
|
const ip_str = try std.fmt.bufPrint(
|
||||||
const ip_str = try std.fmt.bufPrint(&ip_buf, "{d}.{d}.{d}.{d}", .{
|
&ip_buf,
|
||||||
ip & 0xFF,
|
"{f}",
|
||||||
(ip >> 8) & 0xFF,
|
.{req.address},
|
||||||
(ip >> 16) & 0xFF,
|
);
|
||||||
(ip >> 24) & 0xFF,
|
|
||||||
});
|
|
||||||
|
|
||||||
limiter.check(ip_str) catch {
|
if (!limiter.shouldAcceptRequest(ip_str)) {
|
||||||
res.status = 429;
|
res.status = 429;
|
||||||
res.body = "Too Many Requests";
|
res.body = "Too Many Requests";
|
||||||
return;
|
}
|
||||||
};
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Reference in a new issue