stt/src/main.zig
2025-10-01 10:03:47 -07:00

501 lines
21 KiB
Zig

//! STT with callback-based event handling
const std = @import("std");
const builtin = @import("builtin");
const stt = @import("stt.zig");
/// Shutdown flag: set by the SIGINT handler and polled by the main loop.
var should_exit = std.atomic.Value(bool).init(false);
// SAFETY: assigned in main() before the STT session is started, i.e. before
// any speech callback can spawn a child process through it.
/// Global (rather than a local in main) so the SIGCHLD handler can reach the
/// list of child processes and reclaim the one that just terminated.
var handler: SpeechHandler = undefined;
/// State behind the STT event callbacks.
///
/// Counts speech/error events, prints recognized speech to stdout and, when
/// `exec_program` is set, spawns that program once per utterance with the
/// recognized text as its only argument. Spawned children are tracked in
/// `child_processes` so they can be reaped (on SIGCHLD, after a timeout, or
/// at shutdown via `deinit`).
const SpeechHandler = struct {
    allocator: std.mem.Allocator,
    speech_count: u32 = 0,
    error_count: u32 = 0,
    warning_count: u32 = 0,
    recoverable_error_count: u32 = 0,
    exec_program: ?[]const u8 = null,
    child_processes: std.ArrayList(*Process) = .{},
    /// Guards `reclaimProcessesPosix` against concurrent entry.
    reclaiming: std.atomic.Value(bool) = std.atomic.Value(bool).init(false),

    /// Reclaiming is skipped while at most this many entries are tracked.
    const max_children = 5;

    /// Bookkeeping for one spawned child.
    const Process = struct {
        /// Heap-allocated child handle; null once the child has been reaped
        /// and the handle freed.
        child: ?*std.process.Child,
        /// Spawn time (seconds since the epoch), used for the 10 second timeout.
        start: i64,
        /// Copy of the pid taken right after spawning. Only used for logging
        /// during shutdown, where `child.id` was observed to be unreliable.
        id: std.process.Child.Id,
    };

    /// Handle detected speech: count it, print it, and run `exec_program`
    /// (if any) with the recognized text.
    fn onSpeech(ctx: *anyopaque, text: []const u8) void {
        if (builtin.is_test) return; // Suppress output during tests
        const self: *SpeechHandler = @ptrCast(@alignCast(ctx));
        self.speech_count += 1;

        // Print with timestamp for better experience
        const timestamp = std.time.timestamp();
        var stdout_buffer: [1024]u8 = undefined;
        var stdout_writer = std.fs.File.stdout().writer(&stdout_buffer);
        const stdout = &stdout_writer.interface;
        defer stdout.flush() catch std.log.warn("Caught error writing speech data to stdout", .{});
        stdout.print("[{}] Speech {}->{?s}: {s}\n", .{
            timestamp,
            self.speech_count,
            self.exec_program,
            text,
        }) catch std.log.warn("Caught error writing speech data to stdout", .{});

        // Execute program if specified
        if (self.exec_program) |program| self.exec(text) catch |err| {
            std.log.err("Failed to execute program '{s}': {}", .{ program, err });
        };
    }

    /// Spawn `exec_program` with `text` as its argument and start tracking it.
    ///
    /// Must only be called when `exec_program` is non-null. Returns any
    /// allocation or spawn error; in that case nothing is left allocated or
    /// tracked. An error from the trailing reclaim pass is also returned, but
    /// the freshly spawned child then stays tracked so it can still be reaped.
    fn exec(self: *SpeechHandler, text: []const u8) !void {
        const program = self.exec_program.?; // should only be called when exec_program is not null
        {
            // We need to be able to clean up at some point in the future, but
            // we don't care about these processes otherwise.
            const process = try self.allocator.create(Process);
            errdefer self.allocator.destroy(process);

            const child = try self.allocator.create(std.process.Child);
            child.* = std.process.Child.init(&[_][]const u8{ program, text }, self.allocator);
            process.* = .{
                .start = std.time.timestamp(),
                .child = child,
                // SAFETY: assigned after a successful spawn, before any read
                .id = undefined,
            };
            // Free the child handle on failure, unless the SIGCHLD handler
            // already released it (it nulls `process.child` when it does).
            errdefer {
                if (process.child) |c| self.allocator.destroy(c);
            }

            try self.child_processes.append(self.allocator, process);
            errdefer _ = self.child_processes.pop();

            try child.spawn();
            try child.waitForSpawn();
            process.id = child.id;
        }
        // Outside the cleanup scope above: the child is running at this
        // point, so a reclaim failure must not untrack or free it.
        try self.reclaimProcessesPosix(false);
    }

    /// Release tracking data for finished children and kill children that
    /// have been running for 10 seconds or more.
    ///
    /// With `reap_all == false` this is a no-op until more than
    /// `max_children` entries are tracked, and the most recently added entry
    /// is never touched. With `reap_all == true` (shutdown) every remaining
    /// child is additionally waited for and the tracking list is freed; the
    /// handler must not spawn afterwards.
    fn reclaimProcessesPosix(self: *SpeechHandler, reap_all: bool) !void {
        // We could end up called by two threads at the same time (via SIGCHLD
        // and an actual speech event). This flag should prevent that.
        if (self.reclaiming.cmpxchgStrong(false, true, .acquire, .acquire)) |_| return;
        defer self.reclaiming.store(false, .release);

        if (!reap_all and self.child_processes.items.len <= max_children) return;
        std.log.debug("Reclaiming memory from {s} processes", .{if (reap_all) "ALL" else "completed"});
        if (self.child_processes.items.len == 0) {
            // Only reachable when reap_all is set. The list can be empty yet
            // still own capacity (a failed spawn pops its entry), so release
            // the backing storage before returning.
            self.child_processes.deinit(self.allocator);
            return;
        }

        // If we're not reaping everything, we can just as well skip the last
        // one as we just started it
        const end = self.child_processes.items.len - @as(usize, if (reap_all) 0 else 1);
        const now = std.time.timestamp();
        for (0..end) |i| {
            const proc = self.child_processes.items[i];
            // Only children that started at least 10 seconds ago are handled
            // in this loop. When reaping everything, younger children are
            // waited for in the shutdown loop below instead.
            const should_kill = (proc.start + 10) <= now;
            if (proc.child == null or !should_kill) continue;
            const child = proc.child.?;
            const proc_exists = posixPidRunning(child.*) catch |err| {
                // not sure what we do here
                switch (err) {
                    error.ProcessNotFound => unreachable, // handled in posixPidRunning
                    error.PermissionDenied => {
                        std.log.err("Permission denied trying to reap pid {d}", .{child.id});
                        continue; // guess we'll keep it on the list and the OS will deal when we exit?
                    },
                    error.Unexpected => @panic("Unexpected error getting pid information. This should not happen"),
                }
            };
            if (!proc_exists) {
                _ = try child.wait(); // effectively deinit(). We don't care about term value (I hope?)
                self.allocator.destroy(child);
                proc.child = null;
                continue;
            }
            std.log.warn("Process ran longer than 10 seconds, killing pid {d}", .{child.id});
            proc.child = null; // avoid race condition between the kill below and the SIGCHLD processing
            _ = child.kill() catch |err| {
                // really should work at this point
                std.log.err("Permission denied trying to kill pid {d}: {}", .{ child.id, err });
                // The entry no longer references the handle, so free it here
                // as well; otherwise it would be leaked.
                self.allocator.destroy(child);
                continue;
            };
            self.allocator.destroy(child);
        }

        if (reap_all) {
            std.log.debug("Shutting down, waiting for processes to finish", .{});
            for (self.child_processes.items) |proc| {
                if (proc.child) |c| {
                    // Child id seems undefined here for some reason, but on sigchld we're ok
                    // I suspect this might be a race condition somehow but not sure how
                    // We've worked around it by copying the pid out of the child into the process
                    // when we spawn it, then read that here, but it is the only place we use
                    // this value
                    std.log.info("Waiting on pid {d}", .{proc.id});
                    _ = try c.wait();
                    self.allocator.destroy(c);
                }
                self.allocator.destroy(proc);
            }
            self.child_processes.deinit(self.allocator);
            std.log.debug("All processes finished", .{});
            return;
        }

        // TODO: What's the right number here? We want to clear out memory from
        // the array list
        if (self.child_processes.items.len > 20) {
            std.log.debug("consolidating process tracking array", .{});
            var open_procs: usize = 0;
            for (self.child_processes.items) |proc| {
                if (proc.child) |_| open_procs += 1;
            }
            // Rebuild the list with only the entries that still own a child.
            const cp = try self.child_processes.toOwnedSlice(self.allocator);
            defer self.allocator.free(cp);
            try self.child_processes.ensureTotalCapacity(self.allocator, open_procs);
            for (cp) |proc| {
                if (proc.child) |_|
                    self.child_processes.appendAssumeCapacity(proc)
                else
                    self.allocator.destroy(proc);
            }
        }
    }

    /// Report whether the pid of `process` still exists, using `kill(pid, 0)`.
    /// Returns false on `ProcessNotFound`; other kill errors are returned.
    fn posixPidRunning(process: std.process.Child) std.posix.KillError!bool {
        // From https://man7.org/linux/man-pages/man2/kill.2.html:
        //
        // If sig is 0, then no signal is sent, but existence and permission
        // checks are still performed; this can be used to check for the
        // existence of a process ID or process group ID that the caller is
        // permitted to signal.
        std.posix.kill(process.id, 0) catch |err| {
            if (err == error.ProcessNotFound) return false;
            return err; // Permission denied
        };
        return true; // process is running
    }

    /// Reap every tracked child and free all tracking memory. Errors are
    /// logged rather than returned.
    pub fn deinit(self: *SpeechHandler) void {
        self.reclaimProcessesPosix(true) catch |err| std.log.err("Error reclaiming processes: {}", .{err});
    }

    /// Handle basic errors (fallback for compatibility)
    fn onError(ctx: *anyopaque, error_code: stt.Error, message: []const u8) void {
        if (builtin.is_test) return; // Suppress output during tests
        const self: *SpeechHandler = @ptrCast(@alignCast(ctx));
        self.error_count += 1;
        // Print error with timestamp
        const timestamp = std.time.timestamp();
        std.log.err("[{}] Error #{} ({}): {s}", .{ timestamp, self.error_count, error_code, message });
    }

    /// Handle detailed errors with comprehensive information
    fn onDetailedError(ctx: *anyopaque, error_info: stt.ErrorInfo) void {
        const self: *SpeechHandler = @ptrCast(@alignCast(ctx));
        logDetail(self, error_info) catch |e|
            std.log.err("Error writing error {}. Original message: {s}", .{ e, error_info.message });
    }

    /// Count `error_info` and log it as a single multi-line message.
    ///
    /// The message is formatted into a fixed 2048-byte buffer; a formatting
    /// failure (e.g. the buffer is too small) is returned to the caller. The
    /// "Ready to start speech recognition" notification is neither counted
    /// nor logged.
    fn logDetail(self: *SpeechHandler, error_info: stt.ErrorInfo) !void {
        const log = std.log.scoped(.stt);
        // Categorize the error for statistics
        if (error_info.recoverable)
            self.recoverable_error_count += 1
        else
            self.error_count += 1;
        if (builtin.is_test) return; // Suppress output during tests

        // Format complete error message in a buffer
        var buffer: [2048]u8 = undefined;
        var stream = std.io.fixedBufferStream(&buffer);
        const writer = stream.writer();
        try writer.print("{s}", .{error_info.message});
        try writer.print("\n\tCode: {}", .{error_info.error_code});
        if (error_info.context) |context|
            try writer.print("\n\tContext: {s}", .{context});
        if (error_info.system_error) |sys_err|
            try writer.print("\n\tSystem Error: {} ({any})", .{ sys_err, error_info.error_code });
        if (error_info.recovery_suggestion) |suggestion| {
            if (std.mem.eql(u8, "Ready to start speech recognition", suggestion)) {
                // Not a real error: undo exactly the counter bumped above.
                // Decrementing the other one could underflow the u32.
                if (error_info.recoverable)
                    self.recoverable_error_count -= 1
                else
                    self.error_count -= 1;
                return;
            }
            try writer.print("\n\tSuggestion: {s}", .{suggestion});
        }
        if (error_info.recoverable)
            try writer.print("\n\tStatus: Recoverable - system will attempt to continue", .{})
        else
            try writer.print("\n\tStatus: Fatal - intervention may be required", .{});
        const message = stream.getWritten();

        // Determine and call appropriate log function once
        switch (error_info.error_code) {
            stt.Error.InternalError => if (error_info.recoverable) {
                log.info("{s}", .{message});
            } else {
                log.warn("{s}", .{message});
            },
            stt.Error.OutOfMemory,
            stt.Error.ModelLoadError,
            stt.Error.InitializationFailed,
            => log.err("{s}", .{message}),
            else => if (error_info.recoverable)
                log.warn("{s}", .{message})
            else
                log.err("{s}", .{message}),
        }
    }

    /// Get comprehensive statistics for summary. `total_issues` is the sum of
    /// the error, warning and recoverable-error counters.
    fn getStats(self: *const SpeechHandler) struct {
        speech_count: u32,
        error_count: u32,
        warning_count: u32,
        recoverable_error_count: u32,
        total_issues: u32,
    } {
        const total_issues = self.error_count + self.warning_count + self.recoverable_error_count;
        return .{
            .speech_count = self.speech_count,
            .error_count = self.error_count,
            .warning_count = self.warning_count,
            .recoverable_error_count = self.recoverable_error_count,
            .total_issues = total_issues,
        };
    }
};
/// SIGINT handler: asks the main loop to shut down gracefully by raising
/// `should_exit`. Any other signal number is ignored.
fn signalHandler(sig: i32) callconv(.c) void {
    if (sig != std.posix.SIG.INT) return;
    should_exit.store(true, .release);
}
/// SIGCHLD handler: reaps the child named in `info`, releases its tracking
/// handle, then runs a non-exhaustive reclaim pass over the other children.
///
/// The handle is released however the child terminated (normal exit or
/// killed by a signal): once `wait` has returned the process is reaped and
/// its pid may be reused, so keeping the handle would let a later reclaim
/// pass probe or kill an unrelated process.
///
/// NOTE(review): logging and freeing memory are presumably not
/// async-signal-safe; confirm this is acceptable for this program.
fn signalAction(sig: i32, info: *const std.posix.siginfo_t, _: ?*anyopaque) callconv(.c) void {
    // NOTE: info only works correctly if std.posix.SA.SIGINFO is in the flags
    if (sig != std.posix.SIG.CHLD) return;

    const pid = info.fields.common.first.piduid.pid;
    std.log.debug("SIGCHLD on pid {d}", .{pid});
    for (handler.child_processes.items) |proc| {
        const child = proc.child orelse continue;
        if (child.id != pid) continue;

        // I don't *think* this could fail at this point...
        const term = child.wait() catch @panic("child.wait should not throw error at this point");
        switch (term) {
            .Exited => |code| {
                if (code > 0) std.log.warn("Child process exited with non-zero return code {d}", .{code});
            },
            // SA_NOCLDSTOP only suppresses stop notifications; a child
            // terminated by a signal still ends up here.
            else => std.log.warn("Child process {d} did not exit normally ({s})", .{ pid, @tagName(term) }),
        }
        handler.allocator.destroy(child);
        proc.child = null;
        break; // a pid is tracked at most once
    }
    handler.reclaimProcessesPosix(false) catch |err| {
        std.log.err("Caught error reclaiming processes. This is fatal, shutting down. Error: {}", .{err});
        signalHandler(std.posix.SIG.INT);
    };
}
/// Program entry point.
///
/// Usage: stt [--model=<path>] [--exec=<program>]
///
/// Ensures ALSA_CONFIG_PATH is set, parses the command line, resolves the
/// Vosk model directory, starts an STT session and then idles until SIGINT
/// arrives or the session stops listening. A short summary is logged on exit.
/// Exits with status 1 when alsa.conf or the model cannot be found.
pub fn main() !void {
    const stdout = std.fs.File.stdout();
    const stderr = std.fs.File.stderr();

    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    // Check and set ALSA_CONFIG_PATH if not set
    if (std.posix.getenv("ALSA_CONFIG_PATH") == null) {
        std.fs.cwd().access("alsa.conf", .{}) catch {
            _ = stderr.writeAll("Error: alsa.conf file not found. Please put alsa.conf in the current directory or set ALSA_CONFIG_PATH\n") catch {};
            std.process.exit(1);
        };
        const c = @cImport({
            @cInclude("stdlib.h");
        });
        _ = c.setenv("ALSA_CONFIG_PATH", "alsa.conf", 1);
    }

    // Parse command line arguments
    const args = try std.process.argsAlloc(allocator);
    defer std.process.argsFree(allocator, args);

    var model_path: ?[]const u8 = null;
    var exec_program: ?[]const u8 = null;

    // Parse --model and --exec arguments
    const model_prefix = "--model=";
    const exec_prefix = "--exec=";
    for (args[1..]) |arg| {
        if (std.mem.startsWith(u8, arg, model_prefix)) {
            model_path = arg[model_prefix.len..];
        } else if (std.mem.startsWith(u8, arg, exec_prefix)) {
            exec_program = arg[exec_prefix.len..];
        }
    }

    // Create handler with statistics tracking. This must happen before the
    // SIGCHLD handler is installed below, because that handler dereferences
    // the global `handler`.
    handler = SpeechHandler{
        .allocator = allocator,
        .exec_program = exec_program,
    };
    defer handler.deinit();

    const sigintact = std.posix.Sigaction{
        .handler = .{ .handler = signalHandler },
        .mask = std.posix.sigemptyset(),
        .flags = 0,
    };
    std.posix.sigaction(std.c.SIG.INT, &sigintact, null);

    const sigchldact = std.posix.Sigaction{
        .handler = .{ .sigaction = signalAction },
        .mask = std.posix.sigemptyset(),
        .flags = std.posix.SA.NOCLDSTOP | std.posix.SA.SIGINFO,
    };
    std.posix.sigaction(std.c.SIG.CHLD, &sigchldact, null);

    const speech_handler = stt.SpeechEventHandler{
        .onSpeechFn = SpeechHandler.onSpeech,
        .onErrorFn = SpeechHandler.onError,
        .onDetailedErrorFn = SpeechHandler.onDetailedError,
        .ctx = &handler,
    };

    // If no model specified, try default locations
    const default_paths = [_][]const u8{
        "vosk-model-small-en-us-0.15",
        "zig-out/bin/vosk-model-small-en-us-0.15",
        "/usr/share/vosk/models/vosk-model-small-en-us-0.15",
    };
    if (model_path == null) {
        for (default_paths) |path| {
            std.fs.cwd().access(path, .{}) catch continue;
            model_path = path;
            break;
        }
    }

    // Check if model path exists
    if (model_path == null) {
        _ = try stderr.writeAll("Error: Vosk model not found.\n\n");
        _ = try stderr.writeAll("Usage: stt [--model=<path>] [--exec=<program>]\n\n");
        _ = try stderr.writeAll("Locations searched:\n");
        inline for (default_paths) |path|
            _ = try stderr.writeAll("\t" ++ path ++ "\n");
        _ = try stderr.writeAll("Please download the model. A fine model can be downloaded from:\n");
        _ = try stderr.writeAll("\thttps://alphacephei.com/vosk/models/vosk-model-small-en-us-0.15.zip\n");
        std.process.exit(1);
    }
    std.fs.cwd().access(model_path.?, .{}) catch {
        std.log.err("Model path does not exist: {s}", .{model_path.?});
        std.process.exit(1);
    };

    // Initialize STT session with resolved model path
    const options = stt.Options{
        .model_path = model_path.?,
        .audio_device = "default", // Use ALSA default device from alsa.conf
        .event_handler = speech_handler,
        .sample_rate = 16000, // Standard sample rate for speech recognition
        .channels = 2, // Stereo input (will be converted to mono internally)
        .buffer_size = 256, // Existing buffer size for low latency
    };

    std.log.debug("Initializing STT library...", .{});
    var session = stt.Session.init(allocator, options) catch |err| {
        std.log.err("Failed to initialize STT library: {}", .{err});
        std.log.err("Please ensure:", .{});
        std.log.err(" - Audio device '{s}' is available", .{options.audio_device});
        std.log.err(" - Model directory exists at: {s}", .{options.model_path});
        std.log.err(" - You have permission to access the audio device", .{});
        return;
    };
    defer session.deinit();

    std.log.info("Program to execute on speech detection: {?s}", .{exec_program});
    std.log.info("STT library initialized successfully with configuration:", .{});
    std.log.info(" Model path: {s}", .{options.model_path});
    std.log.info(" Audio device: {s}", .{options.audio_device});
    std.log.info(" Sample rate: {} Hz", .{options.sample_rate});
    std.log.info(" Channels: {} (converted to mono)", .{options.channels});
    std.log.info(" Buffer size: {} frames", .{options.buffer_size});
    std.log.info("", .{});

    // Start listening for speech with error handling
    _ = stdout.writeAll("Starting speech recognition...\n") catch {};
    session.start_listening() catch |err| {
        std.log.err("Failed to start listening: {}", .{err});
        switch (err) {
            stt.Error.AudioDeviceError => {
                std.log.err("Audio device error. Please check:", .{});
                std.log.err(" - Device '{s}' exists and is accessible", .{options.audio_device});
                std.log.err(" - No other application is using the device", .{});
                std.log.err(" - You have permission to access audio devices", .{});
            },
            stt.Error.ThreadingError => {
                std.log.err("Threading error. System may be under heavy load.", .{});
            },
            else => {
                std.log.err("Unexpected error during startup.", .{});
            },
        }
        return;
    };
    defer session.stop_listening();

    std.log.info("Speech recognition started successfully", .{});
    _ = stdout.writeAll("Listening for speech... (Press Ctrl+C to exit)\n") catch {};
    _ = stdout.writeAll("Speak into your microphone to see speech recognition results\n") catch {};
    _ = stdout.writeAll("------------------------------------------------------------\n") catch {};

    // Main loop - wait for Ctrl+C signal
    while (!should_exit.load(.acquire)) {
        // Sleep for a short time to avoid busy waiting
        std.Thread.sleep(100 * std.time.ns_per_ms); // 100ms
        // Check if session is still listening (in case of errors)
        if (!session.is_listening()) {
            std.log.err("Speech recognition stopped unexpectedly.", .{});
            break;
        }
    }

    _ = stdout.writeAll("\n----------------------------------------\n") catch {};
    _ = stdout.writeAll("Shutdown signal received, stopping...\n") catch {};

    // Get final statistics from handler
    const stats = handler.getStats();
    std.log.info("Demo Session Summary:", .{});
    std.log.info(" Speech detections: {}", .{stats.speech_count});
    std.log.info(" Fatal errors: {}", .{stats.error_count});
    std.log.info(" Recoverable errors: {}", .{stats.recoverable_error_count});
    std.log.info(" Total issues: {}\n", .{stats.total_issues});
    // Print separately since the lines above are info calls and the one below is an error call
    if (stats.error_count > 0)
        std.log.err("✗ {d} fatal errors occurred during speech recognition.", .{stats.error_count});
    _ = stdout.writeAll("Session completed successfully.\n") catch {};
}
test "handler callbacks" {
    // Smoke test: wire a SpeechHandler into the STT event handler and make
    // sure the speech and error callbacks can be invoked without crashing.
    var speech_state = SpeechHandler{ .allocator = std.testing.allocator };
    const callbacks = stt.SpeechEventHandler{
        .onSpeechFn = SpeechHandler.onSpeech,
        .onErrorFn = SpeechHandler.onError,
        .ctx = &speech_state,
    };

    callbacks.onSpeech("test speech");
    callbacks.onError(stt.Error.AudioDeviceError, "test error");

    // Reaching this point without a crash is the pass condition.
    try std.testing.expect(true);
}