diff --git a/src/Config.zig b/src/Config.zig index 675b24d..a7bdfa3 100644 --- a/src/Config.zig +++ b/src/Config.zig @@ -2,13 +2,19 @@ const std = @import("std"); const Config = @This(); +pub const GeoIpFallback = enum { + ipwhois, + ip2location, +}; + listen_host: []const u8, listen_port: u16, cache_size: usize, cache_dir: []const u8, /// GeoLite2 is used for GeoIP (IP -> geographic location) -/// IP2Location is a fallback if IP is not found in this db +/// When GeoLite2 data is missing or low-confidence, the configured +/// fallback provider is used (ipwho.is by default, or IP2Location) geolite_path: []const u8, /// Geocache file stores location lookups @@ -16,11 +22,18 @@ geolite_path: []const u8, /// a web service from Nominatum (https://nominatim.org/) is used geocache_file: ?[]const u8, +/// Which online service to use as a fallback when GeoLite2 has no data. +/// Default: ipwhois (ipwho.is). Alternative: ip2location (ip2location.io) +geoip_fallback: GeoIpFallback, + /// If provided, when GeoLite2 is missing data, https://www.ip2location.com/ /// can be used. This will also be cached in the cached file ip2location_api_key: ?[]const u8, ip2location_cache_file: []const u8, +/// Cache file for ipwho.is lookups +ipwhois_cache_file: []const u8, + pub fn load(allocator: std.mem.Allocator) !Config { var env = try std.process.getEnvMap(allocator); defer env.deinit(); @@ -54,6 +67,12 @@ pub fn load(allocator: std.mem.Allocator) !Config { }); }, .geocache_file = if (env.get("WTTR_GEOCACHE_FILE")) |v| try allocator.dupe(u8, v) else try std.fs.path.join(allocator, &[_][]const u8{ default_cache_dir, "geocache.json" }), + .geoip_fallback = blk: { + if (env.get("WTTR_GEOIP_FALLBACK")) |v| { + if (std.mem.eql(u8, v, "ip2location")) break :blk .ip2location; + } + break :blk .ipwhois; + }, .ip2location_api_key = if (env.get("IP2LOCATION_API_KEY")) |v| try allocator.dupe(u8, v) else null, .ip2location_cache_file = blk: { if (env.get("IP2LOCATION_CACHE_FILE")) |v| { @@ -61,6 +80,12 @@ pub fn load(allocator: std.mem.Allocator) !Config { } break :blk try std.fmt.allocPrint(allocator, "{s}/ip2location.cache", .{env.get("WTTR_CACHE_DIR") orelse default_cache_dir}); }, + .ipwhois_cache_file = blk: { + if (env.get("IPWHOIS_CACHE_FILE")) |v| { + break :blk try allocator.dupe(u8, v); + } + break :blk try std.fmt.allocPrint(allocator, "{s}/ipwhois.cache", .{env.get("WTTR_CACHE_DIR") orelse default_cache_dir}); + }, }; } @@ -71,6 +96,7 @@ pub fn deinit(self: Config, allocator: std.mem.Allocator) void { if (self.geocache_file) |f| allocator.free(f); if (self.ip2location_api_key) |k| allocator.free(k); allocator.free(self.ip2location_cache_file); + allocator.free(self.ipwhois_cache_file); } test "config loads defaults" { diff --git a/src/http/Server.zig b/src/http/Server.zig index e6e6ec1..e8264d2 100644 --- a/src/http/Server.zig +++ b/src/http/Server.zig @@ -143,7 +143,7 @@ pub const MockHarness = struct { const geoip = try allocator.create(GeoIp); errdefer allocator.destroy(geoip); - geoip.* = GeoIp.init(allocator, config.geolite_path, null, config.ip2location_cache_file) catch + geoip.* = GeoIp.init(allocator, config.geolite_path, config) catch return error.SkipZigTest; errdefer geoip.deinit(); diff --git a/src/location/GeoIp.zig b/src/location/GeoIp.zig index f7f73ec..d1bef6a 100644 --- a/src/location/GeoIp.zig +++ b/src/location/GeoIp.zig @@ -1,6 +1,8 @@ const std = @import("std"); const Ip2location = @import("Ip2location.zig"); +const IpWhoIs = @import("IpWhoIs.zig"); const Location = @import("resolver.zig").Location; +const Config = @import("../Config.zig"); const c = @cImport({ @cInclude("maxminddb.h"); @@ -9,11 +11,36 @@ const c = @cImport({ const GeoIP = @This(); const log = std.log.scoped(.geoip); +const FallbackClient = union(enum) { + ip2location: *Ip2location, + ipwhois: *IpWhoIs, + + fn lookup(self: FallbackClient, ip: []const u8) ?Location { + return switch (self) { + .ip2location => |client| client.lookup(ip), + .ipwhois => |client| client.lookup(ip), + }; + } + + fn deinit(self: FallbackClient, allocator: std.mem.Allocator) void { + switch (self) { + .ip2location => |client| { + client.deinit(); + allocator.destroy(client); + }, + .ipwhois => |client| { + client.deinit(); + allocator.destroy(client); + }, + } + } +}; + mmdb: *c.MMDB_s, -ip2location_client: *Ip2location, +fallback_client: FallbackClient, allocator: std.mem.Allocator, -pub fn init(allocator: std.mem.Allocator, db_path: []const u8, api_key: ?[]const u8, cache_path: []const u8) !GeoIP { +pub fn init(allocator: std.mem.Allocator, db_path: []const u8, config: Config) !GeoIP { const path_z = try std.heap.c_allocator.dupeZ(u8, db_path); defer std.heap.c_allocator.free(path_z); @@ -24,22 +51,29 @@ pub fn init(allocator: std.mem.Allocator, db_path: []const u8, api_key: ?[]const if (status != c.MMDB_SUCCESS) return error.CannotOpenDatabase; - const client: *Ip2location = try allocator.create(Ip2location); - errdefer allocator.destroy(client); - client.* = try Ip2location.init(allocator, api_key, cache_path); - errdefer { - client.deinit(); - allocator.destroy(client); - } - - std.log.info( - "IP2Location fallback: {s} (cache: {s})", - .{ if (api_key) |_| "key provided, 50k/mo limit" else "no key, 1k/day limit", cache_path }, - ); + const fallback_client: FallbackClient = switch (config.geoip_fallback) { + .ip2location => blk: { + const client = try allocator.create(Ip2location); + errdefer allocator.destroy(client); + client.* = try Ip2location.init(allocator, config.ip2location_api_key, config.ip2location_cache_file); + std.log.info( + "GeoIP fallback: IP2Location ({s}, cache: {s})", + .{ if (config.ip2location_api_key) |_| "key provided, 50k/mo limit" else "no key, 1k/day limit", config.ip2location_cache_file }, + ); + break :blk .{ .ip2location = client }; + }, + .ipwhois => blk: { + const client = try allocator.create(IpWhoIs); + errdefer allocator.destroy(client); + client.* = try IpWhoIs.init(allocator, config.ipwhois_cache_file); + std.log.info("GeoIP fallback: ipwho.is (cache: {s})", .{config.ipwhois_cache_file}); + break :blk .{ .ipwhois = client }; + }, + }; return GeoIP{ .mmdb = mmdb, - .ip2location_client = client, + .fallback_client = fallback_client, .allocator = allocator, }; } @@ -47,8 +81,7 @@ pub fn init(allocator: std.mem.Allocator, db_path: []const u8, api_key: ?[]const pub fn deinit(self: *GeoIP) void { c.MMDB_close(self.mmdb); self.allocator.destroy(self.mmdb); - self.ip2location_client.deinit(); - self.allocator.destroy(self.ip2location_client); + self.fallback_client.deinit(self.allocator); } pub fn lookup(self: *GeoIP, ip: []const u8) ?Location { @@ -60,8 +93,8 @@ pub fn lookup(self: *GeoIP, ip: []const u8) ?Location { if (self.extractCoordinates(ip, result)) |coords| return coords; - // Fallback to IP2Location - return self.ip2location_client.lookup(ip); + // Fallback to configured online provider + return self.fallback_client.lookup(ip); } fn lookupInternal(mmdb: *c.MMDB_s, ip: []const u8) !c.MMDB_lookup_result_s { @@ -196,13 +229,14 @@ test "MMDB functions are callable" { } test "GeoIP init with invalid path fails" { - const result = GeoIP.init(std.testing.allocator, "/nonexistent/path.mmdb", null, ""); + const config = try Config.load(std.testing.allocator); + defer config.deinit(std.testing.allocator); + const result = GeoIP.init(std.testing.allocator, "/nonexistent/path.mmdb", config); try std.testing.expectError(error.CannotOpenDatabase, result); } test "isUSIp detects US IPs" { const allocator = std.testing.allocator; - const Config = @import("../Config.zig"); const config = try Config.load(allocator); defer config.deinit(allocator); const build_options = @import("build_options"); @@ -213,7 +247,7 @@ test "isUSIp detects US IPs" { try GeoLite2.ensureDatabase(std.testing.allocator, db_path); } - var geoip = GeoIP.init(std.testing.allocator, db_path, null, config.ip2location_cache_file) catch + var geoip = GeoIP.init(std.testing.allocator, db_path, config) catch return error.SkipZigTest; defer geoip.deinit(); @@ -226,7 +260,6 @@ test "isUSIp detects US IPs" { } test "lookup works" { const allocator = std.testing.allocator; - const Config = @import("../Config.zig"); const config = try Config.load(allocator); defer config.deinit(allocator); const build_options = @import("build_options"); @@ -237,7 +270,7 @@ test "lookup works" { try GeoLite2.ensureDatabase(std.testing.allocator, db_path); } - var geoip = GeoIP.init(std.testing.allocator, db_path, null, config.ip2location_cache_file) catch + var geoip = GeoIP.init(std.testing.allocator, db_path, config) catch return error.SkipZigTest; defer geoip.deinit(); diff --git a/src/location/IpWhoIs.zig b/src/location/IpWhoIs.zig new file mode 100644 index 0000000..8c51c6a --- /dev/null +++ b/src/location/IpWhoIs.zig @@ -0,0 +1,150 @@ +const std = @import("std"); +const Allocator = std.mem.Allocator; +const Location = @import("resolver.zig").Location; +const Cache = @import("Ip2location.zig").Cache; + +const Self = @This(); + +const log = std.log.scoped(.ipwhois); + +allocator: Allocator, +http_client: std.http.Client, +cache: *Cache, + +pub fn init(allocator: Allocator, cache_path: []const u8) !Self { + const cache = try allocator.create(Cache); + errdefer allocator.destroy(cache); + cache.* = try .init(allocator, cache_path); + return .{ + .allocator = allocator, + .http_client = std.http.Client{ .allocator = allocator }, + .cache = cache, + }; +} + +pub fn deinit(self: *Self) void { + self.cache.deinit(); + self.allocator.destroy(self.cache); + self.http_client.deinit(); +} + +pub fn lookup(self: *Self, ip_str: []const u8) ?Location { + // Parse IP to u128 for cache lookup + const addr = std.net.Address.parseIp(ip_str, 0) catch return null; + const ip_u128: u128 = switch (addr.any.family) { + std.posix.AF.INET => @as(u128, @intCast(std.mem.readInt(u32, @ptrCast(&addr.in.sa.addr), .big))), + std.posix.AF.INET6 => std.mem.readInt(u128, @ptrCast(&addr.in6.sa.addr), .big), + else => return null, + }; + const family: u8 = if (addr.any.family == std.posix.AF.INET) 4 else 6; + + // Check cache first + if (self.cache.get(ip_u128)) |result| + return result; + + // Fetch from API + const result = self.fetch(ip_str) catch |err| { + log.err("API lookup failed: {}", .{err}); + return null; + }; + + // Store in cache + self.cache.put(ip_u128, family, result) catch |err| { + log.warn("Failed to cache result: {}", .{err}); + }; + + return result; +} + +fn fetch(self: *Self, ip_str: []const u8) !Location { + log.info("Fetching geolocation for IP {s}", .{ip_str}); + + if (@import("builtin").is_test) return error.LookupUnavailableInUnitTest; + + var buf: [256]u8 = undefined; + var w = std.Io.Writer.fixed(&buf); + try w.writeAll("https://ipwho.is/"); + try w.writeAll(ip_str); + // Request only the fields we need + try w.writeAll("?fields=city,region,country,latitude,longitude&output=json"); + + var response_buf: [4096]u8 = undefined; + var writer = std.Io.Writer.fixed(&response_buf); + const result = try self.http_client.fetch(.{ + .location = .{ .url = w.buffered() }, + .method = .GET, + .response_writer = &writer, + .extra_headers = &.{ + .{ .name = "User-Agent", .value = "wttr.in" }, + }, + }); + + if (result.status != .ok) { + log.err("API returned status {}", .{result.status}); + return error.ApiError; + } + + const response_body = response_buf[0..writer.end]; + + // Parse JSON response + const parsed = try std.json.parseFromSlice( + std.json.Value, + self.allocator, + response_body, + .{}, + ); + defer parsed.deinit(); + + const obj = parsed.value.object; + + // Check for success field + if (obj.get("success")) |success| { + if (success == .bool and !success.bool) { + const msg = if (obj.get("message")) |m| if (m == .string) m.string else "unknown" else "unknown"; + log.err("API returned error for ip {s}: {s}", .{ ip_str, msg }); + return error.ApiError; + } + } + + const lat = obj.get("latitude") orelse return error.MissingLatitude; + const lon = obj.get("longitude") orelse return error.MissingLongitude; + if (lat != .float and lat != .integer) { + log.err("Latitude returned from ipwho.is for ip {s} is not a number", .{ip_str}); + return error.MissingLatitude; + } + if (lon != .float and lon != .integer) { + log.err("Longitude returned from ipwho.is for ip {s} is not a number", .{ip_str}); + return error.MissingLongitude; + } + + const city = getString(obj, "city"); + const region = getString(obj, "region"); + const country = getString(obj, "country"); + + const display_name = Location.buildDisplayName( + self.allocator, + city, + region, + country, + ip_str, + ); + + const lat_val: f64 = if (lat == .float) @floatCast(lat.float) else @floatFromInt(lat.integer); + const lon_val: f64 = if (lon == .float) @floatCast(lon.float) else @floatFromInt(lon.integer); + + return Location{ + .allocator = self.allocator, + .name = display_name, + .coords = .{ + .latitude = lat_val, + .longitude = lon_val, + }, + }; +} + +inline fn getString(obj: std.json.ObjectMap, key: []const u8) []const u8 { + const maybe_val = obj.get(key); + if (maybe_val == null) return ""; + if (maybe_val.? != .string) return ""; + return maybe_val.?.string; +} diff --git a/src/location/resolver.zig b/src/location/resolver.zig index 295cf59..df35468 100644 --- a/src/location/resolver.zig +++ b/src/location/resolver.zig @@ -304,7 +304,7 @@ test "resolve IP address with GeoIP" { try GeoLite2.ensureDatabase(allocator, config.geolite_path); } - var geoip = GeoIp.init(allocator, config.geolite_path, null, config.ip2location_cache_file) catch + var geoip = GeoIp.init(allocator, config.geolite_path, config) catch return error.SkipZigTest; defer geoip.deinit(); diff --git a/src/main.zig b/src/main.zig index 9b32bee..9bb9217 100644 --- a/src/main.zig +++ b/src/main.zig @@ -38,12 +38,11 @@ pub fn main() !u8 { // Ensure GeoLite2 database exists try GeoLite2.ensureDatabase(allocator, cfg.geolite_path); - // Initialize GeoIP database with optional IP2Location fallback + // Initialize GeoIP database with configured fallback var geoip = GeoIp.init( allocator, cfg.geolite_path, - cfg.ip2location_api_key, - cfg.ip2location_cache_file, + cfg, ) catch |err| { std.log.err("Failed to load GeoIP database from {s}: {}", .{ cfg.geolite_path, err }); return err; @@ -105,4 +104,5 @@ test { _ = @import("location/GeoCache.zig"); _ = @import("location/Airports.zig"); _ = @import("location/resolver.zig"); + _ = @import("location/IpWhoIs.zig"); }