"""
|
|
Voice Handler for LiveKit Agent
|
|
|
|
This module handles speech recognition and text-to-speech functionality
|
|
for the LiveKit Chrome automation agent.
|
|
"""
|
|
|
|
import asyncio
|
|
import logging
|
|
import io
|
|
import wave
|
|
from typing import Optional, Dict, Any
|
|
import numpy as np
|
|
|
|
from livekit import rtc
|
|
from livekit.plugins import openai, deepgram
|
|
|
|
|
|
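
# Example configuration shape (illustrative only; these keys are inferred from
# the defaults read in VoiceHandler.__init__ below, not from a published schema):
#
#   config = {
#       "speech": {"provider": "openai", "language": "en-US",
#                  "confidence_threshold": 0.7},
#       "tts": {"provider": "openai", "voice": "alloy", "speed": 1.0},
#   }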


class VoiceHandler:
    """Handles voice recognition and synthesis for the LiveKit agent."""

    def __init__(self, config: Optional[Dict[str, Any]] = None):
        self.config = config or {}
        self.logger = logging.getLogger(__name__)

        # Speech recognition settings
        self.stt_provider = self.config.get('speech', {}).get('provider', 'openai')
        self.language = self.config.get('speech', {}).get('language', 'en-US')
        self.confidence_threshold = self.config.get('speech', {}).get('confidence_threshold', 0.7)

        # Text-to-speech settings
        self.tts_provider = self.config.get('tts', {}).get('provider', 'openai')
        self.voice = self.config.get('tts', {}).get('voice', 'alloy')
        self.speed = self.config.get('tts', {}).get('speed', 1.0)

        # Audio processing
        self.sample_rate = 16000
        self.channels = 1
        self.chunk_size = 1024

        # Components
        self.stt_engine = None
        self.tts_engine = None
        self.audio_buffer = []

    async def initialize(self):
        """Initialize speech recognition and synthesis engines."""
        try:
            # Check if an OpenAI API key is available
            import os
            openai_key = os.getenv('OPENAI_API_KEY')

            # Initialize STT engine
            if self.stt_provider == 'openai' and openai_key:
                self.stt_engine = openai.STT(
                    language=self.language,
                    detect_language=True
                )
            elif self.stt_provider == 'deepgram':
                self.stt_engine = deepgram.STT(
                    language=self.language,
                    model="nova-2"
                )
            else:
                self.logger.warning(
                    f"STT provider {self.stt_provider} not available or API key missing"
                )

            # Initialize TTS engine
            if self.tts_provider == 'openai' and openai_key:
                self.tts_engine = openai.TTS(
                    voice=self.voice,
                    speed=self.speed
                )
            else:
                self.logger.warning(
                    f"TTS provider {self.tts_provider} not available or API key missing"
                )

            self.logger.info(
                f"Voice handler initialized with STT: {self.stt_provider}, TTS: {self.tts_provider}"
            )

        except Exception as e:
            # Don't raise; missing API keys are an expected condition, so just log it
            self.logger.warning(f"Voice handler initialization failed (this is expected without API keys): {e}")

    async def process_audio_frame(self, frame: rtc.AudioFrame) -> Optional[str]:
        """Process an audio frame and return recognized text."""
        try:
            # Convert the frame's raw PCM data to a numpy array
            audio_data = np.frombuffer(frame.data, dtype=np.int16)

            # Add to buffer
            self.audio_buffer.extend(audio_data)

            # Process once we have enough data (about one second of audio)
            if len(self.audio_buffer) >= self.sample_rate:
                text = await self._recognize_speech(self.audio_buffer)
                self.audio_buffer = []  # Clear buffer
                return text

        except Exception as e:
            self.logger.error(f"Error processing audio frame: {e}")

        return None
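
    # Illustrative helper (an assumption, not part of the original module): one
    # way to drive process_audio_frame() from a subscribed audio track, assuming
    # the livekit-rtc AudioStream API that yields frame events as audio arrives.
    async def consume_track(self, track: rtc.Track) -> None:
        """Feed frames from a subscribed audio track into the recognizer."""
        audio_stream = rtc.AudioStream(track)
        async for event in audio_stream:
            text = await self.process_audio_frame(event.frame)
            if text:
                self.logger.info(f"Transcribed: {text}")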

    async def _recognize_speech(self, audio_data: list) -> Optional[str]:
        """Recognize speech from buffered audio samples."""
        try:
            if not self.stt_engine:
                return None

            # Convert to the audio format expected by the STT engine
            audio_array = np.array(audio_data, dtype=np.int16)

            # Create an in-memory audio stream
            stream = self._create_audio_stream(audio_array)

            # Recognize speech (both providers expose the same recognize() call)
            if self.stt_provider in ('openai', 'deepgram'):
                result = await self.stt_engine.recognize(stream)
            else:
                return None

            # Discard results below the confidence threshold
            if hasattr(result, 'confidence') and result.confidence < self.confidence_threshold:
                return None

            text = result.text.strip() if hasattr(result, 'text') else str(result).strip()

            if text:
                self.logger.info(f"Recognized speech: {text}")
                return text

        except Exception as e:
            self.logger.error(f"Error recognizing speech: {e}")

        return None

    def _create_audio_stream(self, audio_data: np.ndarray) -> io.BytesIO:
        """Create an in-memory WAV stream from a numpy array of samples."""
        # Convert samples to raw bytes
        audio_bytes = audio_data.tobytes()

        # Wrap the raw PCM data in a WAV container in memory
        wav_buffer = io.BytesIO()
        with wave.open(wav_buffer, 'wb') as wav_file:
            wav_file.setnchannels(self.channels)
            wav_file.setsampwidth(2)  # 16-bit samples
            wav_file.setframerate(self.sample_rate)
            wav_file.writeframes(audio_bytes)

        wav_buffer.seek(0)
        return wav_buffer

    async def speak_response(self, text: str, room: Optional[rtc.Room] = None) -> bool:
        """Convert text to speech and optionally publish it to a room."""
        try:
            if not self.tts_engine:
                self.logger.warning("TTS engine not initialized")
                return False

            self.logger.info(f"Speaking: {text}")

            # Generate speech
            if self.tts_provider == 'openai':
                audio_stream = await self.tts_engine.synthesize(text)
            else:
                return False

            # If a room is provided, publish the synthesized audio as a track
            if room:
                await self._publish_audio_track(room, audio_stream)

            return True

        except Exception as e:
            self.logger.error(f"Error speaking response: {e}")
            return False

    async def provide_action_feedback(self, action: str, result: str, room: Optional[rtc.Room] = None) -> bool:
        """Provide immediate voice feedback about automation actions."""
        try:
            # Create concise feedback based on the action type
            feedback_text = self._generate_action_feedback(action, result)

            if feedback_text:
                return await self.speak_response(feedback_text, room)

            return True

        except Exception as e:
            self.logger.error(f"Error providing action feedback: {e}")
            return False

    def _generate_action_feedback(self, action: str, result: str) -> str:
        """Generate concise feedback text for different actions."""
        try:
            # Parse the result string to determine success or failure
            success = "success" in result.lower() or "clicked" in result.lower() or "filled" in result.lower()

            if action == "click":
                return "Clicked" if success else "Click failed"
            elif action == "fill":
                return "Field filled" if success else "Fill failed"
            elif action == "navigate":
                return "Navigated" if success else "Navigation failed"
            elif action == "search":
                return "Search completed" if success else "Search failed"
            elif action == "type":
                return "Text entered" if success else "Text entry failed"
            else:
                return "Action completed" if success else "Action failed"

        except Exception:
            return "Action processed"

    async def _publish_audio_track(self, room: rtc.Room, audio_stream):
        """Publish synthesized audio to the room as a local track."""
        try:
            # Create an audio source and a local track backed by it
            source = rtc.AudioSource(self.sample_rate, self.channels)
            track = rtc.LocalAudioTrack.create_audio_track("agent-voice", source)

            # Publish the track as a microphone source
            options = rtc.TrackPublishOptions()
            options.source = rtc.TrackSource.SOURCE_MICROPHONE

            publication = await room.local_participant.publish_track(track, options)

            # Stream the synthesized audio frames into the source
            async for frame in audio_stream:
                await source.capture_frame(frame)

            # Unpublish when done
            await room.local_participant.unpublish_track(publication.sid)

        except Exception as e:
            self.logger.error(f"Error publishing audio track: {e}")

    async def set_language(self, language: str):
        """Change the recognition language and reinitialize the STT engine."""
        self.language = language
        await self.initialize()

    async def set_voice(self, voice: str):
        """Change the TTS voice and reinitialize the TTS engine."""
        self.voice = voice
        await self.initialize()

    def get_supported_languages(self) -> list:
        """Get the list of supported languages."""
        return [
            'en-US', 'en-GB', 'es-ES', 'fr-FR', 'de-DE',
            'it-IT', 'pt-BR', 'ru-RU', 'ja-JP', 'ko-KR', 'zh-CN'
        ]

    def get_supported_voices(self) -> list:
        """Get the list of supported voices."""
        if self.tts_provider == 'openai':
            return ['alloy', 'echo', 'fable', 'onyx', 'nova', 'shimmer']
        return []
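

# Minimal usage sketch (illustrative, not part of the agent's runtime path):
# create a handler, initialize the engines, and speak a short test phrase.
# Assumes OPENAI_API_KEY is set in the environment; no LiveKit room is passed,
# so the audio is synthesized but not published anywhere.
async def _demo() -> None:
    handler = VoiceHandler()
    await handler.initialize()
    await handler.speak_response("Voice handler ready")


if __name__ == "__main__":
    asyncio.run(_demo())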