feat: add voice-ai-engine-development skill for building real-time conversational AI
@@ -0,0 +1,423 @@
"""
Example: Complete Voice AI Engine Implementation

This example demonstrates a minimal but complete voice AI engine
with all core components: Transcriber, Agent, Synthesizer, and WebSocket integration.
"""

import asyncio
from typing import Dict, AsyncGenerator, Callable
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from dataclasses import dataclass
import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# ============================================================================
# Data Models
# ============================================================================

@dataclass
class Transcription:
    message: str
    confidence: float
    is_final: bool
    is_interrupt: bool = False


@dataclass
class AgentResponse:
    message: str
    is_interruptible: bool = True


@dataclass
class SynthesisResult:
    chunk_generator: AsyncGenerator[bytes, None]
    get_message_up_to: Callable[[float], str]


# ============================================================================
# Base Worker Pattern
# ============================================================================

class BaseWorker:
    """Base class for all workers in the pipeline"""

    def __init__(self, input_queue: asyncio.Queue, output_queue: asyncio.Queue):
        self.input_queue = input_queue
        self.output_queue = output_queue
        self.active = False
        self._task = None

    def start(self):
        """Start the worker's processing loop"""
        self.active = True
        self._task = asyncio.create_task(self._run_loop())

    async def _run_loop(self):
        """Main processing loop - runs forever until terminated"""
        while self.active:
            try:
                item = await self.input_queue.get()
                await self.process(item)
            except Exception as e:
                logger.error(f"Worker error: {e}", exc_info=True)

    async def process(self, item):
        """Override this - does the actual work"""
        raise NotImplementedError

    def terminate(self):
        """Stop the worker"""
        self.active = False
        if self._task:
            self._task.cancel()
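

# Usage sketch (illustrative only; UppercaseWorker and _demo_worker are
# hypothetical names, not part of the engine): subclass BaseWorker, override
# process(), and chain stages by feeding one worker's output_queue into the
# next worker's input_queue.
class UppercaseWorker(BaseWorker):
    async def process(self, item: str):
        # Transform the item and hand it to the next pipeline stage
        self.output_queue.put_nowait(item.upper())


async def _demo_worker():
    worker = UppercaseWorker(asyncio.Queue(), asyncio.Queue())
    worker.start()
    worker.input_queue.put_nowait("hello")
    print(await worker.output_queue.get())  # -> "HELLO"
    worker.terminate()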


# ============================================================================
# Transcriber Component
# ============================================================================

class DeepgramTranscriber(BaseWorker):
    """Converts audio chunks to text transcriptions using Deepgram"""

    def __init__(self, config: Dict):
        super().__init__(asyncio.Queue(), asyncio.Queue())
        self.config = config
        self.is_muted = False

    def send_audio(self, chunk: bytes):
        """Client calls this to send audio"""
        if not self.is_muted:
            self.input_queue.put_nowait(chunk)
        else:
            # Send silence instead (prevents echo during bot speech)
            self.input_queue.put_nowait(self.create_silent_chunk(len(chunk)))

    def create_silent_chunk(self, size: int) -> bytes:
        """Create a silent audio chunk"""
        return b'\x00' * size

    def mute(self):
        """Called when bot starts speaking (prevents echo)"""
        self.is_muted = True
        logger.info("🔇 [TRANSCRIBER] Muted")

    def unmute(self):
        """Called when bot stops speaking"""
        self.is_muted = False
        logger.info("🔊 [TRANSCRIBER] Unmuted")

    async def process(self, audio_chunk: bytes):
        """Process audio chunk and generate transcription"""
        # In a real implementation, this would call Deepgram API
        # For this example, we'll simulate a transcription

        # Simulate API call delay
        await asyncio.sleep(0.1)

        # Mock transcription
        transcription = Transcription(
            message="Hello, how can I help you?",
            confidence=0.95,
            is_final=True
        )

        logger.info(f"🎤 [TRANSCRIBER] Received: '{transcription.message}'")
        self.output_queue.put_nowait(transcription)
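

# Sketch of what process() might look like against Deepgram's raw streaming
# endpoint (kept commented out: it needs the `websockets` package, network
# access, and an API key). The URL, query params, and response fields follow
# Deepgram's documented live API; the `deepgramApiKey` config key and the
# `extra_headers` keyword (renamed `additional_headers` in websockets >= 13)
# are assumptions, so treat this as an illustration rather than a drop-in.
#
# import json
# import websockets
#
# async def _stream_deepgram(self, audio_chunks):
#     url = "wss://api.deepgram.com/v1/listen?encoding=linear16&sample_rate=16000"
#     headers = {"Authorization": f"Token {self.config['deepgramApiKey']}"}
#     async with websockets.connect(url, extra_headers=headers) as ws:
#         async def sender():
#             async for chunk in audio_chunks:
#                 await ws.send(chunk)
#         send_task = asyncio.create_task(sender())
#         async for raw in ws:
#             msg = json.loads(raw)
#             alt = msg.get("channel", {}).get("alternatives", [{}])[0]
#             if alt.get("transcript"):
#                 self.output_queue.put_nowait(Transcription(
#                     message=alt["transcript"],
#                     confidence=alt.get("confidence", 0.0),
#                     is_final=msg.get("is_final", False),
#                 ))
#         send_task.cancel()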


# ============================================================================
# Agent Component
# ============================================================================

class GeminiAgent(BaseWorker):
    """LLM-powered conversational agent using Google Gemini"""

    def __init__(self, config: Dict):
        super().__init__(asyncio.Queue(), asyncio.Queue())
        self.config = config
        self.conversation_history = []

    async def process(self, transcription: Transcription):
        """Process transcription and generate response"""
        # Add user message to history
        self.conversation_history.append({
            "role": "user",
            "content": transcription.message
        })

        logger.info(f"🤖 [AGENT] Generating response for: '{transcription.message}'")

        # Generate response (streaming)
        async for response in self.generate_response(transcription.message):
            self.output_queue.put_nowait(response)

    async def generate_response(self, user_input: str) -> AsyncGenerator[AgentResponse, None]:
        """Generate streaming response from LLM"""
        # In a real implementation, this would call Gemini API
        # For this example, we'll simulate a streaming response

        # Simulate streaming delay
        await asyncio.sleep(0.5)

        # IMPORTANT: Buffer entire response before yielding
        # This prevents audio jumping/cutting off
        full_response = f"I understand you said: {user_input}. How can I assist you further?"

        # Add to conversation history
        self.conversation_history.append({
            "role": "assistant",
            "content": full_response
        })

        logger.info(f"🤖 [AGENT] Generated: '{full_response}'")

        # Yield complete response
        yield AgentResponse(
            message=full_response,
            is_interruptible=True
        )


# ============================================================================
# Synthesizer Component
# ============================================================================

class ElevenLabsSynthesizer:
    """Converts text to speech using ElevenLabs"""

    def __init__(self, config: Dict):
        self.config = config

    async def create_speech(self, message: str, chunk_size: int = 1024) -> SynthesisResult:
        """
        Generate speech audio from text

        Returns SynthesisResult with:
        - chunk_generator: AsyncGenerator yielding audio chunks
        - get_message_up_to: Function to get partial text for interrupts
        """

        # In a real implementation, this would call ElevenLabs API
        # For this example, we'll simulate audio generation

        logger.info(f"🔊 [SYNTHESIZER] Synthesizing {len(message)} characters")

        async def chunk_generator():
            # Simulate streaming audio chunks
            num_chunks = len(message) // 10 + 1
            for i in range(num_chunks):
                # Simulate API delay
                await asyncio.sleep(0.1)

                # Mock audio chunk (in reality, this would be PCM audio)
                chunk = b'\x00' * chunk_size
                yield chunk

        def get_message_up_to(seconds: float) -> str:
            """Calculate partial message based on playback time"""
            # Estimate: ~150 words per minute = ~2.5 words per second
            # Rough estimate: 5 characters per word
            chars_per_second = 12.5
            char_index = int(seconds * chars_per_second)
            return message[:char_index]

        return SynthesisResult(
            chunk_generator=chunk_generator(),
            get_message_up_to=get_message_up_to
        )
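

# Worked example of the estimate used by get_message_up_to above: at ~150
# words/min and ~5 chars/word, speech covers ~12.5 chars/s, so an interrupt
# 2.0 s into playback maps to message[:25]. _demo_message_up_to is an
# illustrative helper, not part of the engine; a production system would
# prefer word-level timestamps from the TTS provider when available.
def _demo_message_up_to():
    message = "I think the weather will be nice today and tomorrow."
    chars_per_second = 12.5  # ~150 wpm * ~5 chars/word / 60 s
    for seconds in (1.0, 2.0, 3.0):
        spoken = message[:int(seconds * chars_per_second)]
        print(f"{seconds:.1f}s -> '{spoken}'")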


# ============================================================================
# Output Device
# ============================================================================

class WebsocketOutputDevice:
    """Sends audio chunks to client via WebSocket"""

    def __init__(self, websocket: WebSocket):
        self.websocket = websocket

    async def consume_nonblocking(self, chunk: bytes):
        """Send audio chunk to client"""
        await self.websocket.send_bytes(chunk)


# ============================================================================
# Conversation Orchestrator
# ============================================================================

class StreamingConversation:
    """Orchestrates the entire voice conversation pipeline"""

    def __init__(
        self,
        output_device: WebsocketOutputDevice,
        transcriber: DeepgramTranscriber,
        agent: GeminiAgent,
        synthesizer: ElevenLabsSynthesizer
    ):
        self.output_device = output_device
        self.transcriber = transcriber
        self.agent = agent
        self.synthesizer = synthesizer
        self.is_human_speaking = True
        self.interrupt_event = asyncio.Event()

    async def start(self):
        """Start all workers"""
        logger.info("🚀 [CONVERSATION] Starting...")

        # Start workers
        self.transcriber.start()
        self.agent.start()

        # Start processing pipelines
        asyncio.create_task(self._process_transcriptions())
        asyncio.create_task(self._process_agent_responses())

    async def _process_transcriptions(self):
        """Process transcriptions from transcriber"""
        while True:
            transcription = await self.transcriber.output_queue.get()

            # Check if this is an interrupt
            if not self.is_human_speaking:
                logger.info("⚠️ [INTERRUPT] User interrupted bot")
                self.interrupt_event.set()
                transcription.is_interrupt = True

            self.is_human_speaking = True

            # Send to agent
            await self.agent.input_queue.put(transcription)

    async def _process_agent_responses(self):
        """Process responses from agent and synthesize"""
        while True:
            response = await self.agent.output_queue.get()

            self.is_human_speaking = False

            # Mute transcriber to prevent echo
            self.transcriber.mute()

            # Synthesize and play
            synthesis_result = await self.synthesizer.create_speech(response.message)
            await self._send_speech_to_output(synthesis_result, seconds_per_chunk=0.1)

            # Unmute transcriber
            self.transcriber.unmute()

            self.is_human_speaking = True

    async def _send_speech_to_output(self, synthesis_result: SynthesisResult, seconds_per_chunk: float):
        """
        Send synthesized audio to output with rate limiting

        CRITICAL: Rate limiting enables interrupts to work
        """
        chunk_idx = 0

        async for chunk in synthesis_result.chunk_generator:
            # Check for interrupt
            if self.interrupt_event.is_set():
                logger.info(f"🛑 [INTERRUPT] Stopped after {chunk_idx} chunks")

                # Calculate what was actually spoken
                seconds_spoken = chunk_idx * seconds_per_chunk
                partial_message = synthesis_result.get_message_up_to(seconds_spoken)
                logger.info(f"📝 [INTERRUPT] Partial message: '{partial_message}'")

                # Clear interrupt event
                self.interrupt_event.clear()
                return

            start_time = asyncio.get_running_loop().time()

            # Send chunk to output device
            await self.output_device.consume_nonblocking(chunk)

            # CRITICAL: Wait for chunk to play before sending next one
            # This is what makes interrupts work!
            processing_time = asyncio.get_running_loop().time() - start_time
            await asyncio.sleep(max(seconds_per_chunk - processing_time, 0))

            chunk_idx += 1

    def receive_audio(self, audio_chunk: bytes):
        """Receive audio from client"""
        self.transcriber.send_audio(audio_chunk)

    async def terminate(self):
        """Gracefully shut down all workers"""
        logger.info("🛑 [CONVERSATION] Terminating...")

        self.transcriber.terminate()
        self.agent.terminate()

        # Wait for queues to drain
        await asyncio.sleep(0.5)


# ============================================================================
# WebSocket Endpoint
# ============================================================================

@app.websocket("/conversation")
async def conversation_endpoint(websocket: WebSocket):
    """WebSocket endpoint for voice conversations"""
    await websocket.accept()
    logger.info("✅ [WEBSOCKET] Client connected")

    # Configuration
    config = {
        "transcriberProvider": "deepgram",
        "llmProvider": "gemini",
        "voiceProvider": "elevenlabs",
        "prompt": "You are a helpful AI assistant.",
    }

    # Create components
    transcriber = DeepgramTranscriber(config)
    agent = GeminiAgent(config)
    synthesizer = ElevenLabsSynthesizer(config)
    output_device = WebsocketOutputDevice(websocket)

    # Create conversation
    conversation = StreamingConversation(
        output_device=output_device,
        transcriber=transcriber,
        agent=agent,
        synthesizer=synthesizer
    )

    # Start conversation
    await conversation.start()

    try:
        # Process incoming audio
        async for message in websocket.iter_bytes():
            conversation.receive_audio(message)
    except WebSocketDisconnect:
        logger.info("❌ [WEBSOCKET] Client disconnected")
    except Exception as e:
        logger.error(f"❌ [WEBSOCKET] Error: {e}", exc_info=True)
    finally:
        await conversation.terminate()
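

# Client sketch (illustrative): one way a test script might drive the
# /conversation endpoint. Kept commented out: it assumes the `websockets`
# package, a hypothetical play() function, and 3200-byte frames (0.1 s of
# 16 kHz 16-bit mono PCM -- the sample rate itself is an assumption, since
# this example never pins one down).
#
# import websockets
#
# async def run_client(pcm_chunks):
#     async with websockets.connect("ws://localhost:8000/conversation") as ws:
#         async def pump_mic():
#             for chunk in pcm_chunks:      # e.g. 3200-byte PCM frames
#                 await ws.send(chunk)
#                 await asyncio.sleep(0.1)  # pace like a live microphone
#         send_task = asyncio.create_task(pump_mic())
#         async for audio in ws:            # synthesized audio from the bot
#             play(audio)                   # hypothetical playback function
#         await send_task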


# ============================================================================
# Main Entry Point
# ============================================================================

if __name__ == "__main__":
    import uvicorn

    logger.info("🚀 Starting Voice AI Engine...")
    uvicorn.run(app, host="0.0.0.0", port=8000)
@@ -0,0 +1,239 @@
"""
Example: Gemini Agent Implementation with Streaming

This example shows how to implement a Gemini-powered agent
that properly buffers responses to prevent audio jumping.
"""

import asyncio
from typing import AsyncGenerator, List, Dict
from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)


@dataclass
class Message:
    role: str  # "user" or "assistant"
    content: str


@dataclass
class GeneratedResponse:
    message: str
    is_interruptible: bool = True


class GeminiAgent:
    """
    LLM-powered conversational agent using Google Gemini

    Key Features:
    - Maintains conversation history
    - Streams responses from Gemini API
    - Buffers entire response before yielding (prevents audio jumping)
    - Handles interrupts gracefully
    """

    def __init__(self, config: Dict):
        self.config = config
        self.conversation_history: List[Message] = []
        self.system_prompt = config.get("prompt", "You are a helpful AI assistant.")
        self.current_task = None

    async def generate_response(
        self,
        user_input: str,
        is_interrupt: bool = False
    ) -> AsyncGenerator[GeneratedResponse, None]:
        """
        Generate streaming response from Gemini

        IMPORTANT: This buffers the entire LLM response before yielding
        to prevent audio jumping/cutting off.

        Args:
            user_input: The user's message
            is_interrupt: Whether this is an interrupt

        Yields:
            GeneratedResponse with complete buffered message
        """

        # Add user message to history
        self.conversation_history.append(
            Message(role="user", content=user_input)
        )

        logger.info(f"🤖 [AGENT] Generating response for: '{user_input}'")

        # Build conversation context for Gemini
        contents = self._build_gemini_contents()

        # Stream response from Gemini and buffer it
        full_response = ""

        try:
            # In a real implementation, this would call Gemini API
            # async for chunk in self._create_gemini_stream(contents):
            #     if isinstance(chunk, str):
            #         full_response += chunk

            # For this example, simulate streaming
            async for chunk in self._simulate_gemini_stream(user_input):
                full_response += chunk

                # Log progress (optional)
                if len(full_response) % 50 == 0:
                    logger.debug(f"🤖 [AGENT] Buffered {len(full_response)} chars...")

        except Exception as e:
            logger.error(f"❌ [AGENT] Error generating response: {e}")
            full_response = "I apologize, but I encountered an error. Could you please try again?"

        # CRITICAL: Only yield after buffering the ENTIRE response
        # This prevents multiple TTS calls that cause audio jumping
        if full_response.strip():
            # Add to conversation history
            self.conversation_history.append(
                Message(role="assistant", content=full_response)
            )

            logger.info(f"✅ [AGENT] Generated complete response ({len(full_response)} chars)")

            yield GeneratedResponse(
                message=full_response.strip(),
                is_interruptible=True
            )

    def _build_gemini_contents(self) -> List[Dict]:
        """
        Build conversation contents for Gemini API

        Format:
        [
            {"role": "user", "parts": [{"text": "System: ..."}]},
            {"role": "model", "parts": [{"text": "Understood."}]},
            {"role": "user", "parts": [{"text": "Hello"}]},
            {"role": "model", "parts": [{"text": "Hi there!"}]},
            ...
        ]
        """
        contents = []

        # Add system prompt as first user message
        if self.system_prompt:
            contents.append({
                "role": "user",
                "parts": [{"text": f"System Instruction: {self.system_prompt}"}]
            })
            contents.append({
                "role": "model",
                "parts": [{"text": "Understood."}]
            })

        # Add conversation history
        for message in self.conversation_history:
            role = "user" if message.role == "user" else "model"
            contents.append({
                "role": role,
                "parts": [{"text": message.content}]
            })

        return contents

    async def _simulate_gemini_stream(self, user_input: str) -> AsyncGenerator[str, None]:
        """
        Simulate Gemini streaming response

        In a real implementation, this would be:

        async def _create_gemini_stream(self, contents):
            response = await genai.GenerativeModel('gemini-pro').generate_content_async(
                contents,
                stream=True
            )
            async for chunk in response:
                if chunk.text:
                    yield chunk.text
        """
        # Simulate response
        response = f"I understand you said: {user_input}. How can I assist you further?"

        # Simulate streaming by yielding chunks
        chunk_size = 10
        for i in range(0, len(response), chunk_size):
            chunk = response[i:i + chunk_size]
            await asyncio.sleep(0.05)  # Simulate network delay
            yield chunk

    def update_last_bot_message_on_cut_off(self, partial_message: str):
        """
        Update conversation history when bot is interrupted

        This ensures the conversation history reflects what was actually spoken,
        not what was planned to be spoken.

        Args:
            partial_message: The partial message that was actually spoken
        """
        if self.conversation_history and self.conversation_history[-1].role == "assistant":
            # Update the last bot message with the partial message
            self.conversation_history[-1].content = partial_message
            logger.info(f"📝 [AGENT] Updated history with partial message: '{partial_message}'")

    def cancel_current_task(self):
        """Cancel the current generation task (for interrupts)"""
        if self.current_task and not self.current_task.done():
            self.current_task.cancel()
            logger.info("🛑 [AGENT] Cancelled current generation task")

    def get_conversation_history(self) -> List[Message]:
        """Get the full conversation history"""
        return self.conversation_history.copy()

    def clear_conversation_history(self):
        """Clear the conversation history"""
        self.conversation_history.clear()
        logger.info("🗑️ [AGENT] Cleared conversation history")


# ============================================================================
# Example Usage
# ============================================================================

async def example_usage():
    """Example of how to use the GeminiAgent"""

    # Configure agent
    config = {
        "prompt": "You are a helpful AI assistant specializing in voice conversations.",
        "llmProvider": "gemini"
    }

    # Create agent
    agent = GeminiAgent(config)

    # Simulate conversation
    user_messages = [
        "Hello, how are you?",
        "What's the weather like today?",
        "Thank you!"
    ]

    for user_message in user_messages:
        print(f"\n👤 User: {user_message}")

        # Generate response
        async for response in agent.generate_response(user_message):
            print(f"🤖 Bot: {response.message}")

    # Print conversation history
    print("\n📜 Conversation History:")
    for i, message in enumerate(agent.get_conversation_history(), 1):
        print(f"{i}. {message.role}: {message.content}")


if __name__ == "__main__":
    asyncio.run(example_usage())
@@ -0,0 +1,334 @@
"""
Example: Interrupt System Implementation

This example demonstrates how to implement a robust interrupt system
that allows users to interrupt the bot mid-sentence.
"""

import asyncio
import threading
from typing import Any
from dataclasses import dataclass
import logging

logger = logging.getLogger(__name__)


# ============================================================================
# InterruptibleEvent Pattern
# ============================================================================

class InterruptibleEvent:
    """
    Wrapper for events that can be interrupted

    Every event in the pipeline is wrapped in an InterruptibleEvent,
    allowing the system to stop processing mid-stream.
    """

    def __init__(self, payload: Any, is_interruptible: bool = True):
        self.payload = payload
        self.is_interruptible = is_interruptible
        self.interruption_event = threading.Event()  # Initially not set
        self.interrupted = False

    def interrupt(self) -> bool:
        """
        Interrupt this event

        Returns:
            True if the event was interrupted, False if it was not interruptible
        """
        if not self.is_interruptible:
            return False

        if not self.interrupted:
            self.interruption_event.set()  # Signal to stop!
            self.interrupted = True
            logger.info("⚠️ [INTERRUPT] Event interrupted")
            return True

        return False

    def is_interrupted(self) -> bool:
        """Check if this event has been interrupted"""
        return self.interruption_event.is_set()
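

# Minimal sketch of the wrapper in use (_demo_interruptible_event is an
# illustrative name): a producer wraps a payload, consumers poll
# is_interrupted() while streaming, and interrupt() flips the flag from
# the outside.
def _demo_interruptible_event():
    event = InterruptibleEvent(payload="agent response", is_interruptible=True)
    assert not event.is_interrupted()
    event.interrupt()              # e.g. the user started talking over the bot
    assert event.is_interrupted()  # consumers should now stop streaming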


# ============================================================================
# Conversation with Interrupt Support
# ============================================================================

class ConversationWithInterrupts:
    """
    Conversation orchestrator with interrupt support

    Key Features:
    - Tracks all in-flight interruptible events
    - Broadcasts interrupts to all workers
    - Cancels current tasks
    - Updates conversation history with partial messages
    """

    def __init__(self):
        self.is_human_speaking = True
        self.interruptible_events = asyncio.Queue()
        self.agent = None  # Set externally
        self.synthesizer_worker = None  # Set externally

    def broadcast_interrupt(self) -> bool:
        """
        Broadcast interrupt to all in-flight events

        This is called when the user starts speaking while the bot is speaking.

        Returns:
            True if any events were interrupted
        """
        num_interrupts = 0

        # Interrupt all queued events
        while True:
            try:
                interruptible_event = self.interruptible_events.get_nowait()
                if interruptible_event.interrupt():
                    num_interrupts += 1
            except asyncio.QueueEmpty:
                break

        # Cancel current tasks
        if self.agent:
            self.agent.cancel_current_task()

        if self.synthesizer_worker:
            self.synthesizer_worker.cancel_current_task()

        logger.info(f"⚠️ [INTERRUPT] Interrupted {num_interrupts} events")

        return num_interrupts > 0

    def add_interruptible_event(self, event: InterruptibleEvent):
        """Add an event to the interruptible queue"""
        self.interruptible_events.put_nowait(event)
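

# Wiring sketch (illustrative, not part of the example below): the stop_event
# that send_speech_to_output() polls is naturally the interruption_event of
# the InterruptibleEvent wrapping that response, so broadcast_interrupt()
# reaches the playback loop without extra plumbing:
#
#     event = InterruptibleEvent(payload=agent_response, is_interruptible=True)
#     conversation.add_interruptible_event(event)
#     await synthesis_worker.send_speech_to_output(
#         message=agent_response.message,
#         synthesis_result=synthesis_result,
#         stop_event=event.interruption_event,
#         seconds_per_chunk=0.1,
#     )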


# ============================================================================
# Synthesis Worker with Interrupt Support
# ============================================================================

class SynthesisWorkerWithInterrupts:
    """
    Synthesis worker that supports interrupts

    Key Features:
    - Checks for interrupts before sending each audio chunk
    - Calculates partial message when interrupted
    - Updates agent's conversation history with partial message
    """

    def __init__(self, agent, output_device):
        self.agent = agent
        self.output_device = output_device
        self.current_task = None

    async def send_speech_to_output(
        self,
        message: str,
        synthesis_result,
        stop_event: threading.Event,
        seconds_per_chunk: float = 0.1
    ) -> tuple[str, bool]:
        """
        Send synthesized speech to output with interrupt support

        Args:
            message: The full message being synthesized
            synthesis_result: SynthesisResult with chunk_generator and get_message_up_to
            stop_event: Event that signals when to stop (interrupt)
            seconds_per_chunk: Duration of each audio chunk in seconds

        Returns:
            Tuple of (message_sent, was_cut_off)
            - message_sent: The actual message sent (partial if interrupted)
            - was_cut_off: True if interrupted, False if completed
        """
        chunk_idx = 0

        async for chunk_result in synthesis_result.chunk_generator:
            # CRITICAL: Check for interrupt before sending each chunk
            if stop_event.is_set():
                logger.info(f"🛑 [SYNTHESIZER] Interrupted after {chunk_idx} chunks")

                # Calculate what was actually spoken
                seconds_spoken = chunk_idx * seconds_per_chunk
                partial_message = synthesis_result.get_message_up_to(seconds_spoken)

                logger.info(f"📝 [SYNTHESIZER] Partial message: '{partial_message}'")

                return partial_message, True  # cut_off = True

            start_time = asyncio.get_running_loop().time()

            # Send chunk to output device
            await self.output_device.consume_nonblocking(chunk_result.chunk)

            # CRITICAL: Wait for chunk to play before sending next one
            # This is what makes interrupts work!
            processing_time = asyncio.get_running_loop().time() - start_time
            await asyncio.sleep(max(seconds_per_chunk - processing_time, 0))

            chunk_idx += 1

        # Completed without interruption
        logger.info(f"✅ [SYNTHESIZER] Completed {chunk_idx} chunks")
        return message, False  # cut_off = False

    def cancel_current_task(self):
        """Cancel the current synthesis task"""
        if self.current_task and not self.current_task.done():
            self.current_task.cancel()
            logger.info("🛑 [SYNTHESIZER] Cancelled current task")


# ============================================================================
# Transcription Worker with Interrupt Detection
# ============================================================================

class TranscriptionWorkerWithInterrupts:
    """
    Transcription worker that detects interrupts

    Key Features:
    - Detects when user speaks while bot is speaking
    - Marks transcription as interrupt
    - Triggers broadcast_interrupt()
    """

    def __init__(self, conversation):
        self.conversation = conversation

    async def process(self, transcription):
        """
        Process transcription and detect interrupts

        If the user starts speaking while the bot is speaking,
        this is an interrupt.
        """

        # Check if this is an interrupt
        if not self.conversation.is_human_speaking:
            logger.info("⚠️ [TRANSCRIPTION] User interrupted bot!")

            # Broadcast interrupt to all in-flight events
            interrupted = self.conversation.broadcast_interrupt()
            transcription.is_interrupt = interrupted

        # Update speaking state
        self.conversation.is_human_speaking = True

        # Continue processing transcription...
        logger.info(f"🎤 [TRANSCRIPTION] Received: '{transcription.message}'")


# ============================================================================
# Example Usage
# ============================================================================

@dataclass
class MockTranscription:
    message: str
    is_interrupt: bool = False


@dataclass
class MockAudioChunk:
    chunk: bytes


class MockSynthesisResult:
    def __init__(self):
        # Match the SynthesisResult contract used by send_speech_to_output:
        # chunk_generator is an async generator object, not an unbound method
        self.chunk_generator = self._generate_chunks()

    async def _generate_chunks(self):
        """Generate mock audio chunks (10 chunks x 0.1 s = 1 second of audio)"""
        for i in range(10):
            await asyncio.sleep(0.1)
            yield MockAudioChunk(chunk=b'\x00' * 1024)

    def get_message_up_to(self, seconds: float) -> str:
        """Get partial message up to specified seconds"""
        full_message = "I think the weather will be nice today and tomorrow and the day after."
        chars_per_second = len(full_message) / 1.0  # 10 chunks x 0.1 s = 1 second total
        char_index = int(seconds * chars_per_second)
        return full_message[:char_index]


async def example_interrupt_scenario():
    """
    Example scenario: User interrupts bot mid-sentence
    """

    print("🎬 Scenario: User interrupts bot mid-sentence\n")

    # Create conversation
    conversation = ConversationWithInterrupts()

    # Create mock components
    class MockAgent:
        def cancel_current_task(self):
            print("🛑 [AGENT] Task cancelled")

        def update_last_bot_message_on_cut_off(self, partial_message):
            print(f"📝 [AGENT] Updated history: '{partial_message}'")

    class MockOutputDevice:
        async def consume_nonblocking(self, chunk):
            pass

    agent = MockAgent()
    output_device = MockOutputDevice()
    conversation.agent = agent

    # Create synthesis worker
    synthesis_worker = SynthesisWorkerWithInterrupts(agent, output_device)
    conversation.synthesizer_worker = synthesis_worker

    # Create interruptible event
    stop_event = threading.Event()
    interruptible_event = InterruptibleEvent(
        payload="Bot is speaking...",
        is_interruptible=True
    )
    conversation.add_interruptible_event(interruptible_event)

    # Start bot speaking
    print("🤖 Bot starts speaking: 'I think the weather will be nice today and tomorrow and the day after.'\n")
    conversation.is_human_speaking = False

    # Simulate synthesis in background
    synthesis_result = MockSynthesisResult()
    synthesis_task = asyncio.create_task(
        synthesis_worker.send_speech_to_output(
            message="I think the weather will be nice today and tomorrow and the day after.",
            synthesis_result=synthesis_result,
            stop_event=stop_event,
            seconds_per_chunk=0.1
        )
    )

    # Wait a bit, then interrupt
    await asyncio.sleep(0.3)

    print("👤 User interrupts: 'Stop!'\n")

    # Trigger interrupt
    conversation.broadcast_interrupt()
    stop_event.set()

    # Wait for synthesis to finish
    message_sent, was_cut_off = await synthesis_task

    print(f"\n✅ Result:")
    print(f" - Message sent: '{message_sent}'")
    print(f" - Was cut off: {was_cut_off}")

    # Update agent history
    if was_cut_off:
        agent.update_last_bot_message_on_cut_off(message_sent)


if __name__ == "__main__":
    asyncio.run(example_interrupt_scenario())