clean up and channel detection

This commit is contained in:
Emil Lerch 2025-10-01 20:30:17 -07:00
parent 34754de586
commit 6559bf86a3
Signed by: lobo
GPG key ID: A7B62D657EF764F8
3 changed files with 42 additions and 156 deletions

View file

@ -249,24 +249,6 @@ const SpeechHandler = struct {
log.err("{s}", .{message}), log.err("{s}", .{message}),
} }
} }
/// Get comprehensive statistics for summary
fn getStats(self: *const SpeechHandler) struct {
speech_count: u32,
error_count: u32,
warning_count: u32,
recoverable_error_count: u32,
total_issues: u32,
} {
const total_issues = self.error_count + self.warning_count + self.recoverable_error_count;
return .{
.speech_count = self.speech_count,
.error_count = self.error_count,
.warning_count = self.warning_count,
.recoverable_error_count = self.recoverable_error_count,
.total_issues = total_issues,
};
}
}; };
/// Signal handler for graceful shutdown /// Signal handler for graceful shutdown
@ -401,7 +383,6 @@ pub fn main() !void {
.audio_device = "default", // Use ALSA default device from alsa.conf .audio_device = "default", // Use ALSA default device from alsa.conf
.event_handler = speech_handler, .event_handler = speech_handler,
.sample_rate = 16000, // Standard sample rate for speech recognition .sample_rate = 16000, // Standard sample rate for speech recognition
.channels = 2, // Stereo input (will be converted to mono internally)
.buffer_size = 256, // Existing buffer size for low latency .buffer_size = 256, // Existing buffer size for low latency
}; };
@ -421,13 +402,12 @@ pub fn main() !void {
std.log.info(" Model path: {s}", .{options.model_path}); std.log.info(" Model path: {s}", .{options.model_path});
std.log.info(" Audio device: {s}", .{options.audio_device}); std.log.info(" Audio device: {s}", .{options.audio_device});
std.log.info(" Sample rate: {} Hz", .{options.sample_rate}); std.log.info(" Sample rate: {} Hz", .{options.sample_rate});
std.log.info(" Channels: {} (converted to mono)", .{options.channels});
std.log.info(" Buffer size: {} frames", .{options.buffer_size}); std.log.info(" Buffer size: {} frames", .{options.buffer_size});
std.log.info("", .{}); std.log.info("", .{});
// Start listening for speech with error handling // Start listening for speech with error handling
_ = stdout.writeAll("Starting speech recognition...\n") catch {}; _ = stdout.writeAll("Starting speech recognition...\n") catch {};
session.start_listening() catch |err| { session.start() catch |err| {
std.log.err("Failed to start listening: {}", .{err}); std.log.err("Failed to start listening: {}", .{err});
switch (err) { switch (err) {
stt.Error.AudioDeviceError => { stt.Error.AudioDeviceError => {
@ -445,7 +425,7 @@ pub fn main() !void {
} }
return; return;
}; };
defer session.stop_listening(); defer session.stop();
std.log.info("Speech recognition started successfully", .{}); std.log.info("Speech recognition started successfully", .{});
_ = stdout.writeAll("Listening for speech... (Press Ctrl+C to exit)\n") catch {}; _ = stdout.writeAll("Listening for speech... (Press Ctrl+C to exit)\n") catch {};
@ -458,7 +438,7 @@ pub fn main() !void {
std.Thread.sleep(100 * std.time.ns_per_ms); // 100ms std.Thread.sleep(100 * std.time.ns_per_ms); // 100ms
// Check if session is still listening (in case of errors) // Check if session is still listening (in case of errors)
if (!session.is_listening()) { if (!session.listening) {
std.log.err("Speech recognition stopped unexpectedly.", .{}); std.log.err("Speech recognition stopped unexpectedly.", .{});
break; break;
} }
@ -468,16 +448,15 @@ pub fn main() !void {
_ = stdout.writeAll("Shutdown signal received, stopping...\n") catch {}; _ = stdout.writeAll("Shutdown signal received, stopping...\n") catch {};
// Get final statistics from handler // Get final statistics from handler
const stats = handler.getStats();
std.log.info("Demo Session Summary:", .{}); std.log.info("Demo Session Summary:", .{});
std.log.info(" Speech detections: {}", .{stats.speech_count}); std.log.info(" Speech detections: {}", .{handler.speech_count});
std.log.info(" Fatal errors: {}", .{stats.error_count}); std.log.info(" Fatal errors: {}", .{handler.error_count});
std.log.info(" Recoverable errors: {}", .{stats.recoverable_error_count}); std.log.info(" Recoverable errors: {}", .{handler.recoverable_error_count});
std.log.info(" Total issues: {}\n", .{stats.total_issues}); std.log.info(" Total issues: {}\n", .{handler.error_count + handler.warning_count + handler.recoverable_error_count});
// Print seperately since ^^ are info calls and vv is an error call // Print seperately since ^^ are info calls and vv is an error call
if (stats.error_count > 0) if (handler.error_count > 0)
std.log.err("✗ {d} fatal errors occurred during speech recognition.", .{stats.error_count}); std.log.err("✗ {d} fatal errors occurred during speech recognition.", .{handler.error_count});
_ = stdout.writeAll("Session completed successfully.\n") catch {}; _ = stdout.writeAll("Session completed successfully.\n") catch {};
} }

View file

@ -82,16 +82,6 @@ pub const ErrorInfo = struct {
}; };
} }
/// Create error info with system error code
pub fn initWithSystemError(error_code: Error, message: []const u8, system_error: i32) ErrorInfo {
return ErrorInfo{
.error_code = error_code,
.message = message,
.system_error = system_error,
.timestamp = std.time.timestamp(),
};
}
/// Create error info with context /// Create error info with context
pub fn initWithContext(error_code: Error, message: []const u8, context: []const u8) ErrorInfo { pub fn initWithContext(error_code: Error, message: []const u8, context: []const u8) ErrorInfo {
return ErrorInfo{ return ErrorInfo{
@ -332,7 +322,7 @@ pub const AlsaCapture = struct {
device_name: []const u8, device_name: []const u8,
/// Sample rate /// Sample rate
sample_rate: u32, sample_rate: u32,
/// Number of channels /// Number of channels. Available after open()
channels: u32, channels: u32,
/// Buffer size in frames /// Buffer size in frames
buffer_size: u32, buffer_size: u32,
@ -340,30 +330,25 @@ pub const AlsaCapture = struct {
period_size: u32, period_size: u32,
/// Audio buffer for captured data /// Audio buffer for captured data
audio_buffer: AudioBuffer, audio_buffer: AudioBuffer,
/// Temporary buffer for ALSA reads
temp_buffer: []i16,
/// Allocator for memory management /// Allocator for memory management
allocator: std.mem.Allocator, allocator: std.mem.Allocator,
/// Initialize ALSA capture with specified parameters /// Initialize ALSA capture with specified parameters
pub fn init(allocator: std.mem.Allocator, device_name: []const u8, sample_rate: u32, channels: u32, buffer_size: u32) !Self { pub fn init(allocator: std.mem.Allocator, device_name: []const u8, sample_rate: u32, buffer_size: u32) !Self {
// Calculate period size (typically 1/4 of buffer size) // Calculate period size (typically 1/4 of buffer size)
const period_size = buffer_size / 4; const period_size = buffer_size / 4;
// Create audio buffer (make it larger than ALSA buffer to prevent overruns) // Create audio buffer (make it larger than ALSA buffer to prevent overruns)
const audio_buffer = try AudioBuffer.init(allocator, buffer_size * 4); const audio_buffer = try AudioBuffer.init(allocator, buffer_size * 4);
// Allocate temporary buffer for ALSA reads
const temp_buffer = try allocator.alloc(i16, period_size * channels);
return Self{ return Self{
.device_name = device_name, .device_name = device_name,
.sample_rate = sample_rate, .sample_rate = sample_rate,
.channels = channels,
.buffer_size = buffer_size, .buffer_size = buffer_size,
.period_size = period_size, .period_size = period_size,
.audio_buffer = audio_buffer, .audio_buffer = audio_buffer,
.temp_buffer = temp_buffer, // SAFETY: this is set based on number of channels detected during open()
.channels = undefined,
.allocator = allocator, .allocator = allocator,
}; };
} }
@ -372,7 +357,6 @@ pub const AlsaCapture = struct {
pub fn deinit(self: *Self) void { pub fn deinit(self: *Self) void {
self.close(); self.close();
self.audio_buffer.deinit(); self.audio_buffer.deinit();
self.allocator.free(self.temp_buffer);
} }
/// Open ALSA device and configure parameters with detailed error reporting /// Open ALSA device and configure parameters with detailed error reporting
@ -415,18 +399,17 @@ pub const AlsaCapture = struct {
err = c.snd_pcm_hw_params_set_format(self.pcm_handle, hw_params, c.SND_PCM_FORMAT_S16_LE); err = c.snd_pcm_hw_params_set_format(self.pcm_handle, hw_params, c.SND_PCM_FORMAT_S16_LE);
if (err < 0) return Error.SetFormatError; if (err < 0) return Error.SetFormatError;
// SAFETY: min/max is set in c calls before use just below
var min: c_uint = undefined;
err = c.snd_pcm_hw_params_get_channels_min(hw_params, &min);
if (err < 0) return Error.SetChannelError;
self.channels = min;
// Set number of channels // Set number of channels
err = c.snd_pcm_hw_params_set_channels(self.pcm_handle, hw_params, self.channels); err = c.snd_pcm_hw_params_set_channels(self.pcm_handle, hw_params, self.channels);
if (err < 0) { if (err < 0) {
// SAFETY: min/max is set in c calls before use just below std.log.err("error setting number of channels. Must be at least {d}", .{min});
var min: c_uint = undefined;
// SAFETY: min/max is set in c calls before use just below
var max: c_uint = undefined;
err = c.snd_pcm_hw_params_get_channels_min(hw_params, &min);
if (err < 0) return Error.SetChannelError;
err = c.snd_pcm_hw_params_get_channels_max(hw_params, &max);
if (err < 0) return Error.SetChannelError;
std.log.err("error setting number of channels. Must be between {d} and {d}", .{ min, max });
return Error.SetChannelError; return Error.SetChannelError;
} }
@ -463,13 +446,16 @@ pub const AlsaCapture = struct {
} }
/// Read audio data from ALSA device and process it /// Read audio data from ALSA device and process it
pub fn readAudio(self: *Self) !usize { fn readAudio(self: *Self) !usize {
if (self.pcm_handle == null) { if (self.pcm_handle == null)
return Error.AudioDeviceError; return Error.AudioDeviceError;
}
// Allocate temporary buffer for ALSA reads
const temp_buffer = try self.allocator.alloc(i16, self.period_size * self.channels);
defer self.allocator.free(temp_buffer);
// Read audio data from ALSA // Read audio data from ALSA
const frames_read = c.snd_pcm_readi(self.pcm_handle, self.temp_buffer.ptr, self.period_size); const frames_read = c.snd_pcm_readi(self.pcm_handle, temp_buffer.ptr, self.period_size);
if (frames_read < 0) { if (frames_read < 0) {
// Handle underrun or other errors // Handle underrun or other errors
@ -486,13 +472,13 @@ pub const AlsaCapture = struct {
// Process audio based on channel configuration // Process audio based on channel configuration
if (self.channels == 1) { if (self.channels == 1) {
// Mono input - write directly to buffer // Mono input - write directly to buffer
_ = self.audio_buffer.write(self.temp_buffer[0..samples_read]); _ = self.audio_buffer.write(temp_buffer[0..samples_read]);
} else if (self.channels == 2) { } else if (self.channels == 2) {
// Stereo input - convert to mono // Stereo input - convert to mono
const mono_buffer = try self.allocator.alloc(i16, @as(usize, @intCast(frames_read))); const mono_buffer = try self.allocator.alloc(i16, @as(usize, @intCast(frames_read)));
defer self.allocator.free(mono_buffer); defer self.allocator.free(mono_buffer);
const mono_samples = AudioConverter.stereoToMono(self.temp_buffer[0..samples_read], mono_buffer); const mono_samples = AudioConverter.stereoToMono(temp_buffer[0..samples_read], mono_buffer);
_ = self.audio_buffer.write(mono_buffer[0..mono_samples]); _ = self.audio_buffer.write(mono_buffer[0..mono_samples]);
} else { } else {
// Multi-channel input - take first channel only // Multi-channel input - take first channel only
@ -500,7 +486,7 @@ pub const AlsaCapture = struct {
defer self.allocator.free(mono_buffer); defer self.allocator.free(mono_buffer);
for (0..@as(usize, @intCast(frames_read))) |i| { for (0..@as(usize, @intCast(frames_read))) |i| {
mono_buffer[i] = self.temp_buffer[i * self.channels]; mono_buffer[i] = temp_buffer[i * self.channels];
} }
_ = self.audio_buffer.write(mono_buffer); _ = self.audio_buffer.write(mono_buffer);
} }
@ -529,8 +515,10 @@ pub const Options = struct {
event_handler: SpeechEventHandler, event_handler: SpeechEventHandler,
/// Sample rate for audio processing (default: 16000) /// Sample rate for audio processing (default: 16000)
sample_rate: u32 = 16000, sample_rate: u32 = 16000,
/// Number of audio channels (default: 2 for stereo)
channels: u32 = 2, // Channels will be detected and used
// /// Number of audio channels (default: 2 for stereo)
// channels: u32 = 2,
/// Audio buffer size in frames (default: 256) /// Audio buffer size in frames (default: 256)
buffer_size: u32 = 256, buffer_size: u32 = 256,
}; };
@ -600,7 +588,6 @@ pub const Session = struct {
allocator, allocator,
options.audio_device, options.audio_device,
options.sample_rate, options.sample_rate,
options.channels,
options.buffer_size, options.buffer_size,
) catch |err| { ) catch |err| {
const error_info = switch (err) { const error_info = switch (err) {
@ -1024,10 +1011,10 @@ pub const Session = struct {
} }
} }
if (quote_start) |start| { if (quote_start) |s| {
// Find closing quote // Find closing quote
if (std.mem.indexOf(u8, json_str[start..], "\"")) |quote_end| { if (std.mem.indexOf(u8, json_str[s..], "\"")) |quote_end| {
const text = json_str[start .. start + quote_end]; const text = json_str[s .. s + quote_end];
// Only invoke callback if text is not empty // Only invoke callback if text is not empty
if (text.len > 0 and !std.mem.eql(u8, text, " ")) { if (text.len > 0 and !std.mem.eql(u8, text, " ")) {
@ -1131,9 +1118,6 @@ pub const Session = struct {
if (options.sample_rate == 0 or options.sample_rate > 48000) { if (options.sample_rate == 0 or options.sample_rate > 48000) {
return Error.InvalidParameter; return Error.InvalidParameter;
} }
if (options.channels == 0 or options.channels > 8) {
return Error.InvalidParameter;
}
if (options.buffer_size == 0 or options.buffer_size > 8192) { if (options.buffer_size == 0 or options.buffer_size > 8192) {
return Error.InvalidParameter; return Error.InvalidParameter;
} }
@ -1176,7 +1160,7 @@ pub const Session = struct {
/// Returns: /// Returns:
/// - void on success /// - void on success
/// - Error on failure /// - Error on failure
pub fn start_listening(self: *Session) Error!void { pub fn start(self: *Session) Error!void {
if (!self.initialized) { if (!self.initialized) {
return Error.InvalidState; return Error.InvalidState;
} }
@ -1226,7 +1210,7 @@ pub const Session = struct {
/// ///
/// This stops audio capture and speech recognition processing. /// This stops audio capture and speech recognition processing.
/// Any ongoing processing will be completed before returning. /// Any ongoing processing will be completed before returning.
pub fn stop_listening(self: *Session) void { pub fn stop(self: *Session) void {
if (!self.listening) { if (!self.listening) {
return; return;
} }
@ -1257,16 +1241,6 @@ pub const Session = struct {
self.listening = false; self.listening = false;
} }
/// Check if the session is currently listening
pub fn is_listening(self: *const Session) bool {
return self.listening;
}
/// Check if the session is initialized
pub fn is_initialized(self: *const Session) bool {
return self.initialized;
}
/// Deinitialize the STT session and free all resources /// Deinitialize the STT session and free all resources
/// ///
/// This must be called to properly clean up the session. /// This must be called to properly clean up the session.
@ -1274,7 +1248,7 @@ pub const Session = struct {
pub fn deinit(self: *Session) void { pub fn deinit(self: *Session) void {
// Ensure we're not listening before cleanup // Ensure we're not listening before cleanup
if (self.listening) { if (self.listening) {
self.stop_listening(); self.stop();
} }
// Detach any remaining threads to prevent hanging // Detach any remaining threads to prevent hanging
@ -1333,62 +1307,6 @@ pub fn init(allocator: std.mem.Allocator, options: Options) Error!Session {
return Session.init(allocator, options); return Session.init(allocator, options);
} }
/// C-compatible API functions for use from other languages
/// These wrap the Zig API with C calling conventions
/// Opaque handle type for C API
pub const Handle = opaque {};
/// Initialize STT library (C API)
///
/// Parameters:
/// - model_path: Null-terminated path to Vosk model
/// - audio_device: Null-terminated ALSA device name
///
/// Returns:
/// - Pointer to Handle on success
/// - null on failure
pub export fn stt_init(model_path: [*:0]const u8, audio_device: [*:0]const u8) ?*Handle {
// TODO: Implement C API wrapper in subsequent tasks
_ = model_path;
_ = audio_device;
return null;
}
/// Set speech detection callback (C API)
/// TODO: Implement in subsequent tasks with proper C-compatible callback types
// pub export fn stt_set_speech_callback(handle: *Handle, callback: SpeechCallback, user_data: ?*anyopaque) void {
// _ = handle;
// _ = callback;
// _ = user_data;
// }
/// Set error callback (C API)
/// TODO: Implement in subsequent tasks with proper C-compatible callback types
// pub export fn stt_set_error_callback(handle: *Handle, callback: ErrorCallback, user_data: ?*anyopaque) void {
// _ = handle;
// _ = callback;
// _ = user_data;
// }
/// Start listening (C API)
pub export fn stt_start_listening(handle: *Handle) c_int {
_ = handle;
// TODO: Implement in subsequent tasks
return -1; // Error for now
}
/// Stop listening (C API)
pub export fn stt_stop_listening(handle: *Handle) void {
_ = handle;
// TODO: Implement in subsequent tasks
}
/// Deinitialize STT library (C API)
pub export fn stt_deinit(handle: *Handle) void {
_ = handle;
// TODO: Implement in subsequent tasks
}
// Tests // Tests
test "Error enum" { test "Error enum" {
const testing = std.testing; const testing = std.testing;
@ -1435,7 +1353,6 @@ test "Options validation" {
try testing.expectEqualStrings("/path/to/model", valid_options.model_path); try testing.expectEqualStrings("/path/to/model", valid_options.model_path);
try testing.expectEqualStrings("hw:0,0", valid_options.audio_device); try testing.expectEqualStrings("hw:0,0", valid_options.audio_device);
try testing.expect(valid_options.sample_rate == 16000); try testing.expect(valid_options.sample_rate == 16000);
try testing.expect(valid_options.channels == 2);
try testing.expect(valid_options.buffer_size == 256); try testing.expect(valid_options.buffer_size == 256);
} }
@ -1472,7 +1389,6 @@ test "Session state management" {
try testing.expectEqualStrings("/path/to/model", options.model_path); try testing.expectEqualStrings("/path/to/model", options.model_path);
try testing.expectEqualStrings("hw:0,0", options.audio_device); try testing.expectEqualStrings("hw:0,0", options.audio_device);
try testing.expect(options.sample_rate == 16000); try testing.expect(options.sample_rate == 16000);
try testing.expect(options.channels == 2);
try testing.expect(options.buffer_size == 256); try testing.expect(options.buffer_size == 256);
} }
@ -1610,7 +1526,7 @@ test "AlsaCapture initialization" {
const allocator = gpa.allocator(); const allocator = gpa.allocator();
// Test ALSA capture initialization (without actually opening device) // Test ALSA capture initialization (without actually opening device)
var capture = AlsaCapture.init(allocator, "hw:0,0", 16000, 2, 1024) catch |err| { var capture = AlsaCapture.init(allocator, "hw:0,0", 16000, 1024) catch |err| {
// If ALSA initialization fails (e.g., no audio device), that's expected in test environment // If ALSA initialization fails (e.g., no audio device), that's expected in test environment
if (err == error.OutOfMemory) { if (err == error.OutOfMemory) {
return err; return err;
@ -1621,7 +1537,6 @@ test "AlsaCapture initialization" {
// Test basic properties // Test basic properties
try testing.expect(capture.sample_rate == 16000); try testing.expect(capture.sample_rate == 16000);
try testing.expect(capture.channels == 2);
try testing.expect(capture.buffer_size == 1024); try testing.expect(capture.buffer_size == 1024);
try testing.expect(capture.period_size == 256); // buffer_size / 4 try testing.expect(capture.period_size == 256); // buffer_size / 4
try testing.expect(capture.pcm_handle == null); // Not opened yet try testing.expect(capture.pcm_handle == null); // Not opened yet
@ -1701,7 +1616,6 @@ test "Session session management API" {
try testing.expectEqualStrings("/invalid/path", options.model_path); try testing.expectEqualStrings("/invalid/path", options.model_path);
try testing.expectEqualStrings("hw:0,0", options.audio_device); try testing.expectEqualStrings("hw:0,0", options.audio_device);
try testing.expect(options.sample_rate == 16000); try testing.expect(options.sample_rate == 16000);
try testing.expect(options.channels == 2);
try testing.expect(options.buffer_size == 256); try testing.expect(options.buffer_size == 256);
// Test options validation // Test options validation

View file

@ -137,10 +137,6 @@ test "Error types and ErrorInfo" {
try testing.expect(basic_error.context == null); try testing.expect(basic_error.context == null);
try testing.expect(basic_error.recoverable == false); try testing.expect(basic_error.recoverable == false);
// Test error info with system error
const system_error = stt.ErrorInfo.initWithSystemError(stt.Error.AudioDeviceError, "System error", -1);
try testing.expect(system_error.system_error.? == -1);
// Test error info with context // Test error info with context
const context_error = stt.ErrorInfo.initWithContext(stt.Error.ModelLoadError, "Context error", "/path/to/model"); const context_error = stt.ErrorInfo.initWithContext(stt.Error.ModelLoadError, "Context error", "/path/to/model");
try testing.expectEqualStrings("/path/to/model", context_error.context.?); try testing.expectEqualStrings("/path/to/model", context_error.context.?);
@ -270,7 +266,6 @@ test "Session initialization error handling" {
try testing.expectEqualStrings("/nonexistent/path", invalid_options.model_path); try testing.expectEqualStrings("/nonexistent/path", invalid_options.model_path);
try testing.expectEqualStrings("hw:999,0", invalid_options.audio_device); try testing.expectEqualStrings("hw:999,0", invalid_options.audio_device);
try testing.expect(invalid_options.sample_rate == 16000); try testing.expect(invalid_options.sample_rate == 16000);
try testing.expect(invalid_options.channels == 2);
try testing.expect(invalid_options.buffer_size == 256); try testing.expect(invalid_options.buffer_size == 256);
} }
@ -289,7 +284,6 @@ test "Session mock initialization and cleanup" {
.audio_device = "hw:0,0", .audio_device = "hw:0,0",
.event_handler = speech_handler, .event_handler = speech_handler,
.sample_rate = 16000, .sample_rate = 16000,
.channels = 2,
.buffer_size = 256, .buffer_size = 256,
}; };
@ -297,7 +291,6 @@ test "Session mock initialization and cleanup" {
try testing.expectEqualStrings("test/model/path", valid_options.model_path); try testing.expectEqualStrings("test/model/path", valid_options.model_path);
try testing.expectEqualStrings("hw:0,0", valid_options.audio_device); try testing.expectEqualStrings("hw:0,0", valid_options.audio_device);
try testing.expect(valid_options.sample_rate == 16000); try testing.expect(valid_options.sample_rate == 16000);
try testing.expect(valid_options.channels == 2);
try testing.expect(valid_options.buffer_size == 256); try testing.expect(valid_options.buffer_size == 256);
} }