diff --git a/src/net/http.zig b/src/net/http.zig index b9664c3..8942fad 100644 --- a/src/net/http.zig +++ b/src/net/http.zig @@ -51,16 +51,26 @@ pub const Client = struct { fn request(self: *Client, method: std.http.Method, url: []const u8, body: ?[]const u8, extra_headers: []const std.http.Header) HttpError!Response { var attempt: u8 = 0; while (true) : (attempt += 1) { - if (self.doRequest(method, url, body, extra_headers)) |response| { - return classifyResponse(response); - } else |_| { + const response = self.doRequest(method, url, body, extra_headers) catch { if (attempt >= self.max_retries) return HttpError.RequestFailed; - const backoff = self.base_backoff_ms * std.math.shl(u64, 1, attempt); - std.Thread.sleep(backoff * std.time.ns_per_ms); - } + self.backoffSleep(attempt); + continue; + }; + return classifyResponse(response) catch |err| { + if (err == HttpError.ServerError and attempt < self.max_retries) { + self.backoffSleep(attempt); + continue; + } + return err; + }; } } + fn backoffSleep(self: *Client, attempt: u8) void { + const backoff = self.base_backoff_ms * std.math.shl(u64, 1, attempt); + std.Thread.sleep(backoff * std.time.ns_per_ms); + } + fn doRequest(self: *Client, method: std.http.Method, url: []const u8, body: ?[]const u8, extra_headers: []const std.http.Header) HttpError!Response { var aw: std.Io.Writer.Allocating = .init(self.allocator); @@ -88,18 +98,23 @@ pub const Client = struct { } fn classifyResponse(response: Response) HttpError!Response { - return switch (response.status) { - .ok => response, - .too_many_requests => HttpError.RateLimited, - .unauthorized, .forbidden => HttpError.Unauthorized, - .not_found => HttpError.NotFound, - .internal_server_error, .bad_gateway, .service_unavailable, .gateway_timeout => HttpError.ServerError, - else => HttpError.InvalidResponse, - }; + switch (response.status) { + .ok => return response, + else => { + response.allocator.free(response.body); + return switch (response.status) { + .too_many_requests => HttpError.RateLimited, + .unauthorized, .forbidden => HttpError.Unauthorized, + .not_found => HttpError.NotFound, + .internal_server_error, .bad_gateway, .service_unavailable, .gateway_timeout => HttpError.ServerError, + else => HttpError.InvalidResponse, + }; + }, + } } }; -/// Build a URL with query parameters. +/// Build a URL with query parameters. Values are percent-encoded per RFC 3986. pub fn buildUrl( allocator: std.mem.Allocator, base: []const u8, @@ -113,20 +128,26 @@ pub fn buildUrl( try aw.writer.writeByte(if (i == 0) '?' else '&'); try aw.writer.writeAll(param[0]); try aw.writer.writeByte('='); - for (param[1]) |c| { - switch (c) { - ' ' => try aw.writer.writeAll("%20"), - '&' => try aw.writer.writeAll("%26"), - '=' => try aw.writer.writeAll("%3D"), - '+' => try aw.writer.writeAll("%2B"), - else => try aw.writer.writeByte(c), - } - } + try std.Uri.Component.percentEncode(&aw.writer, param[1], isQueryValueChar); } return aw.toOwnedSlice(); } +/// RFC 3986 query-safe characters, excluding '&' and '=' which delimit +/// key=value pairs within the query string. +fn isQueryValueChar(c: u8) bool { + return switch (c) { + // Unreserved characters (RFC 3986 section 2.3) + 'A'...'Z', 'a'...'z', '0'...'9', '-', '.', '_', '~' => true, + // Sub-delimiters safe in query values (excludes '&' and '=') + '!', '$', '\'', '(', ')', '*', '+', ',', ';' => true, + // Additional query/path characters + ':', '@', '/', '?' => true, + else => false, + }; +} + test "buildUrl" { const allocator = std.testing.allocator; const url = try buildUrl(allocator, "https://api.example.com/v1/data", &.{