"""Conversation Memory Manager."""
from dataclasses import dataclass, field
from typing import Optional
from enum import Enum
import tiktoken
import time
class MemoryStrategy(str, Enum):
    '''Available policies for keeping conversation history bounded.

    Members are also plain strings, so they compare equal to their values.
    '''

    # Keep last N messages
    SLIDING_WINDOW = "sliding_window"
    # Keep within token budget
    TOKEN_LIMIT = "token_limit"
    # Summarise old messages
    SUMMARY_BUFFER = "summary_buffer"
@dataclass
class Message:
    '''A single conversation turn plus bookkeeping metadata.'''

    # Speaker of the turn: "system", "user", "assistant"
    role: str
    # Text of the turn
    content: str
    # Creation time; filled from time.time() when not supplied
    timestamp: float = field(default_factory=time.time)
    # Token length of ``content``; 0 unless the creator supplies it
    token_count: int = 0
class ConversationMemory:
    '''
    Manages conversation history with multiple strategies
    to stay within LLM context window limits.

    Every call to ``add`` appends a message and then applies the configured
    ``MemoryStrategy``:

    * ``SLIDING_WINDOW`` keeps all system messages plus the last
      ``max_messages`` non-system messages.
    * ``TOKEN_LIMIT`` drops the oldest non-system messages until the summed
      per-message token counts fit in ``max_tokens``.
    * ``SUMMARY_BUFFER`` replaces older non-system messages with a single
      system-role summary produced by ``summarise_fn``; without a
      ``summarise_fn`` it behaves like ``TOKEN_LIMIT``.

    Token counts cover message content only; any per-message overhead the
    chat API adds is not included.
    '''

    # Number of most recent non-system messages kept verbatim when summarising.
    _KEEP_RECENT = 4
    # Prefix of the synthetic system message that carries the running summary.
    _SUMMARY_PREFIX = "[Earlier conversation summary]: "
    # Encoding used when tiktoken does not recognise the model name.
    _FALLBACK_ENCODING = "cl100k_base"

    def __init__(
        self,
        system_prompt: Optional[str] = None,
        strategy: MemoryStrategy = MemoryStrategy.TOKEN_LIMIT,
        max_tokens: int = 4096,  # Keep within this token budget
        max_messages: int = 20,  # For sliding window
        model: str = "gpt-4o",
        summarise_fn=None,  # Callable for summary strategy
    ):
        '''Create a memory, optionally seeded with a system prompt.

        Args:
            system_prompt: If non-empty, added as the first "system" message.
            strategy: Trimming policy applied after every ``add``.
            max_tokens: Token budget for TOKEN_LIMIT and SUMMARY_BUFFER.
            max_messages: Non-system messages kept by SLIDING_WINDOW.
            model: Model name used to select the tiktoken encoding. Names
                tiktoken does not know fall back to a generic encoding, so
                counts are then approximate.
            summarise_fn: Callable that receives a list of
                ``{"role", "content"}`` dicts and returns the summary text.
        '''
        self.strategy = strategy
        self.max_tokens = max_tokens
        self.max_messages = max_messages
        self.summarise_fn = summarise_fn
        try:
            self._enc = tiktoken.encoding_for_model(model)
        except KeyError:
            # Unknown model name: approximate counts beat failing outright.
            self._enc = tiktoken.get_encoding(self._FALLBACK_ENCODING)
        self._messages: list[Message] = []
        self._summary: Optional[str] = None  # For summary buffer strategy
        # The message object currently holding the summary, tracked by
        # identity so it can be replaced instead of accumulating.
        self._summary_msg: Optional[Message] = None
        if system_prompt:
            self.add("system", system_prompt)

    def _count_tokens(self, text: str) -> int:
        '''Return the number of tokens ``text`` encodes to.'''
        return len(self._enc.encode(text))

    def add(self, role: str, content: str) -> None:
        '''Add a message and trim history if needed.'''
        msg = Message(
            role=role,
            content=content,
            token_count=self._count_tokens(content),
        )
        self._messages.append(msg)
        self._trim()

    def _trim(self) -> None:
        '''Dispatch to the trimming routine for the configured strategy.'''
        if self.strategy == MemoryStrategy.SLIDING_WINDOW:
            self._trim_sliding_window()
        elif self.strategy == MemoryStrategy.TOKEN_LIMIT:
            self._trim_by_tokens()
        elif self.strategy == MemoryStrategy.SUMMARY_BUFFER:
            self._trim_with_summary()

    def _trim_sliding_window(self) -> None:
        '''Keep system messages + last N non-system messages.

        System messages are moved ahead of the retained non-system messages.
        A ``max_messages`` of zero (or less) keeps no non-system messages.
        '''
        system = [m for m in self._messages if m.role == "system"]
        non_system = [m for m in self._messages if m.role != "system"]
        limit = max(self.max_messages, 0)
        # ``non_system[-0:]`` would be the whole list, so zero is explicit.
        recent = non_system[-limit:] if limit else []
        self._messages = system + recent

    def _trim_by_tokens(self) -> None:
        '''Remove oldest non-system messages until under token budget.

        System messages are never removed, and the last remaining message is
        never removed either. If the protected messages alone exceed the
        budget the history is left over budget rather than looping forever.
        '''
        excess = self.total_tokens - self.max_tokens
        if excess <= 0:
            return
        remaining = len(self._messages)
        kept: list[Message] = []
        # Single oldest-first pass: evict while still over budget.
        for msg in self._messages:
            if excess > 0 and remaining > 1 and msg.role != "system":
                excess -= msg.token_count
                remaining -= 1
            else:
                kept.append(msg)
        self._messages = kept

    def _trim_with_summary(self) -> None:
        '''Summarise old messages when over the token limit.

        All but the last ``_KEEP_RECENT`` non-system messages are handed to
        ``summarise_fn`` (preceded by the previous summary, if any) and
        replaced by one system message holding the new summary. Nothing is
        changed when there are too few non-system messages to summarise.
        '''
        if self.total_tokens <= self.max_tokens:
            return
        if not self.summarise_fn:
            self._trim_by_tokens()
            return

        previous = self._summary_msg
        # Exclude the old summary here; it is superseded by the new one below.
        system = [
            m for m in self._messages
            if m.role == "system" and m is not previous
        ]
        non_system = [m for m in self._messages if m.role != "system"]
        if len(non_system) <= self._KEEP_RECENT:
            return

        to_summarise = non_system[:-self._KEEP_RECENT]
        to_keep = non_system[-self._KEEP_RECENT:]
        payload = [{"role": m.role, "content": m.content} for m in to_summarise]
        if previous is not None:
            # Carry earlier context forward so replacing the summary loses nothing.
            payload.insert(0, {"role": previous.role, "content": previous.content})

        summary_text = self.summarise_fn(payload)
        self._summary = summary_text

        # Rebuild messages with summary as context
        summary_content = f"{self._SUMMARY_PREFIX}{summary_text}"
        summary_msg = Message(
            role="system",
            content=summary_content,
            # Count the stored content, prefix included, so totals are accurate.
            token_count=self._count_tokens(summary_content),
        )
        self._summary_msg = summary_msg
        self._messages = system + [summary_msg] + to_keep

    @property
    def total_tokens(self) -> int:
        '''Sum of the stored token counts of all held messages.'''
        return sum(m.token_count for m in self._messages)

    @property
    def messages(self) -> list[dict]:
        '''Return messages in format expected by OpenAI API.'''
        return [{"role": m.role, "content": m.content} for m in self._messages]

    def clear(self, keep_system: bool = True) -> None:
        '''Reset conversation history.

        Args:
            keep_system: When true, system messages are retained, except the
                generated summary, which describes the history being cleared.
        '''
        if keep_system:
            self._messages = [
                m for m in self._messages
                if m.role == "system" and m is not self._summary_msg
            ]
        else:
            self._messages = []
        self._summary = None
        self._summary_msg = None

    def __repr__(self) -> str:
        return (f"ConversationMemory({len(self._messages)} messages, "
                f"{self.total_tokens} tokens)")
# Usage
if __name__ == "__main__":
    from openai import OpenAI

    # One model name for both the tokenizer and the API call, so token
    # counts match the model that is actually queried.
    MODEL = "gpt-4o"

    client = OpenAI()
    memory = ConversationMemory(
        system_prompt="You are a helpful AI assistant.",
        strategy=MemoryStrategy.TOKEN_LIMIT,
        max_tokens=4096,
        model=MODEL,
    )

    def chat(user_input: str) -> str:
        '''Send one user turn with the current history and record the reply.'''
        memory.add("user", user_input)
        completion = client.chat.completions.create(
            model=MODEL,
            messages=memory.messages,
        )
        # The API declares message content as nullable; store an empty
        # string instead of passing None to the tokenizer.
        response = completion.choices[0].message.content or ""
        memory.add("assistant", response)
        return response

    print(chat("What is RAG?"))
    print(chat("How does it differ from fine-tuning?"))
    print(f"Memory state: {memory}")