Concept #42 · Hard · practical-coding

Write a streaming response handler for LLM output.

#gen-ai #python #llm

Answer

Streaming Response Handler for LLM Output

python
import sys
import time
from dataclasses import dataclass, field
from typing import AsyncIterator, Callable, Iterator, Optional

from openai import OpenAI


@dataclass
class StreamingStats:
    '''Timing and volume statistics for one streamed completion.'''

    # Number of non-empty content deltas received. The API usually sends
    # about one token per delta, so this approximates (but is not guaranteed
    # to equal) the completion's token count.
    total_tokens: int = 0
    # Wall-clock time from just before the request until the stream ended.
    elapsed_ms: float = 0.0
    # total_tokens divided by elapsed seconds (0.0 if no time was measured).
    tokens_per_second: float = 0.0
    # Time until the first non-empty content delta; stays 0.0 if none arrived.
    first_token_ms: float = 0.0


def stream_to_stdout(
    prompt: str,
    model: str = "gpt-4o",
    client: Optional["OpenAI"] = None,
) -> StreamingStats:
    '''Stream an LLM response directly to stdout and return timing stats.

    Args:
        prompt: User message, sent as a single-turn chat.
        model: Chat model name.
        client: Optional pre-built ``OpenAI`` client to reuse across calls
            (avoids building a new HTTP client per request). A new client is
            created when omitted.

    Returns:
        A ``StreamingStats`` with the delta count, total latency, throughput
        and time to first delta.
    '''
    if client is None:
        client = OpenAI()
    stats = StreamingStats()
    start = time.perf_counter()

    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )

    try:
        for chunk in stream:
            # Some chunks carry no choices (e.g. a trailing usage chunk or a
            # content-filter annotation); indexing [0] would raise IndexError.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta.content
            if not token:
                continue
            if stats.total_tokens == 0:
                stats.first_token_ms = (time.perf_counter() - start) * 1000
            sys.stdout.write(token)
            sys.stdout.flush()
            stats.total_tokens += 1
    finally:
        # Release the HTTP connection even if writing to stdout fails midway.
        # close() is looked up defensively because not every SDK version's
        # stream object exposes it.
        close = getattr(stream, "close", None)
        if close is not None:
            close()

    stats.elapsed_ms = (time.perf_counter() - start) * 1000
    elapsed_s = stats.elapsed_ms / 1000
    # Guard against a zero-length interval rather than raising ZeroDivisionError.
    stats.tokens_per_second = stats.total_tokens / elapsed_s if elapsed_s > 0 else 0.0
    print()  # Newline after stream
    return stats


def stream_with_callback(
    prompt: str,
    on_token: Callable[[str], None],
    on_complete: Optional[Callable[[str], None]] = None,
    model: str = "gpt-4o",
    client: Optional["OpenAI"] = None,
) -> str:
    '''
    Stream tokens, calling ``on_token`` for each, and return the full text.

    Useful for WebSocket / SSE handlers that push each delta as it arrives.

    Args:
        prompt: User message, sent as a single-turn chat.
        on_token: Called once per non-empty content delta, in arrival order.
        on_complete: If given, called once with the concatenated text after
            the stream finishes without error.
        model: Chat model name.
        client: Optional pre-built ``OpenAI`` client to reuse; a new one is
            created when omitted.

    Returns:
        The concatenation of every delta passed to ``on_token``.
    '''
    if client is None:
        client = OpenAI()
    full_response = []

    stream = client.chat.completions.create(
        model=model,
        messages=[{"role": "user", "content": prompt}],
        stream=True,
    )

    try:
        for chunk in stream:
            # Skip chunks without choices (usage / content-filter chunks)
            # instead of failing with IndexError.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta.content
            if token:
                on_token(token)
                full_response.append(token)
    finally:
        # Close the HTTP stream even if on_token raises. The close() lookup
        # is defensive because not every SDK stream object provides it.
        close = getattr(stream, "close", None)
        if close is not None:
            close()

    complete_text = "".join(full_response)
    if on_complete is not None:
        on_complete(complete_text)
    return complete_text


def stream_generator(
    prompt: str,
    system: Optional[str] = None,
    model: str = "gpt-4o",
    client: Optional["OpenAI"] = None,
) -> Iterator[str]:
    '''
    Generator that yields content deltas one at a time.

    Ideal for FastAPI StreamingResponse or other lazy consumers. The request
    is only sent once iteration starts.

    Args:
        prompt: User message.
        system: Optional system message, prepended when truthy.
        model: Chat model name.
        client: Optional pre-built ``OpenAI`` client to reuse; a new one is
            created when omitted.

    Yields:
        Each non-empty content delta, in arrival order.
    '''
    if client is None:
        client = OpenAI()
    messages = []
    if system:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": prompt})

    stream = client.chat.completions.create(
        model=model,
        messages=messages,
        stream=True,
    )

    try:
        for chunk in stream:
            # Chunks without choices (usage / content-filter) carry no text.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta.content
            if token:
                yield token
    finally:
        # If the consumer stops early (e.g. an HTTP client disconnects and the
        # server closes this generator), release the connection now rather
        # than at garbage collection. close() is looked up defensively because
        # not every SDK stream object provides it.
        close = getattr(stream, "close", None)
        if close is not None:
            close()


# FastAPI SSE endpoint example
def fastapi_streaming_example():
    from fastapi import FastAPI
    from fastapi.responses import StreamingResponse

    app = FastAPI()

    def sse_generator(prompt: str):
        '''Wrap token stream in Server-Sent Events format.'''
        for token in stream_generator(prompt):
            # Escape newlines for SSE protocol
            token_escaped = token.replace("\n", "\\n")
            yield f"data: {token_escaped}\n\n"
        yield "data: [DONE]\n\n"

    @app.post("/stream")
    async def stream_endpoint(prompt: str):
        return StreamingResponse(
            sse_generator(prompt),
            media_type="text/event-stream",
            headers={
                "Cache-Control": "no-cache",
                "X-Accel-Buffering": "no",  # Disable nginx buffering
            }
        )
    return app


# Async streaming with error handling
async def async_stream(
    prompt: str,
    model: str = "gpt-4o",
    client: Optional["AsyncOpenAI"] = None,
) -> AsyncIterator[str]:
    '''Asynchronously yield content deltas for a single-turn prompt.

    Exceptions (``Exception`` subclasses) raised while sending the request or
    reading the stream are not propagated. Instead, one final item
    ``"[Error: <message>]"`` is yielded, so consumers that forward tokens
    verbatim surface the failure in-band. Errors from creating the client
    itself still propagate.

    Args:
        prompt: User message, sent as a single-turn chat.
        model: Chat model name.
        client: Optional pre-built ``AsyncOpenAI`` client to reuse; a new one
            is created when omitted.

    Yields:
        Each non-empty content delta, then possibly one error marker.
    '''
    if client is None:
        from openai import AsyncOpenAI
        client = AsyncOpenAI()

    stream = None
    try:
        # create(stream=True) yields raw chunks with .choices. The .stream()
        # helper instead yields typed events (and is missing in older SDKs).
        stream = await client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )
        async for chunk in stream:
            # Chunks without choices (usage / content-filter) carry no text.
            if not chunk.choices:
                continue
            token = chunk.choices[0].delta.content
            if token:
                yield token
    except Exception as e:
        # Deliberate best-effort: report the failure in-band rather than raise.
        yield f"[Error: {str(e)}]"
    finally:
        # Release the connection even if the consumer stops iterating early.
        # close() is looked up defensively because not every SDK stream
        # object provides it.
        close = getattr(stream, "close", None)
        if close is not None:
            await close()


# Demo: stream one prompt to the terminal, then print a one-line summary
# of the collected timing statistics.
if __name__ == "__main__":
    rule = "-" * 40
    print("Streaming response:")
    print(rule)
    stats = stream_to_stdout("Explain RAG in 3 sentences.")
    summary = (
        f"\nStats: {stats.total_tokens} tokens in {stats.elapsed_ms:.0f}ms "
        f"({stats.tokens_per_second:.1f} tok/s), "
        f"first token: {stats.first_token_ms:.0f}ms"
    )
    print(summary)