Graph Memory Pattern

The Graph Memory pattern models agent interactions and knowledge as a network of interconnected entities, relationships, and concepts. This approach excels at capturing complex relationships, enabling reasoning over connected information, and discovering indirect associations that linear memory patterns might miss.

Overview

The Graph Memory pattern represents information as nodes (entities) connected by edges (relationships), creating a knowledge graph that grows with each interaction. Key components include:

  • Entities: People, places, concepts, objects mentioned in conversations
  • Relationships: Connections between entities (works_at, located_in, similar_to)
  • Attributes: Properties of entities (age, color, category)
  • Temporal Information: When relationships were established or modified
  • Confidence Scores: Reliability measures for entities and relationships

This pattern enables sophisticated reasoning through graph traversal, relationship inference, and multi-hop question answering.
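
For intuition, a single extracted fact bundles all five components. The snippet below is a purely illustrative data shape with hypothetical names and values, not the storage format used by the implementation that follows:

# Hypothetical fact: "Alice works at Acme Corp" as a graph triple
fact = {
    'source': {'name': 'Alice', 'type': 'Person', 'attributes': {'age': 30}},
    'relationship': 'works_at',                # Typed edge
    'target': {'name': 'Acme Corp', 'type': 'Organization'},
    'confidence': 0.9,                         # Reliability score
    'created_at': '2024-05-01T12:00:00',       # Temporal information
}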

Architecture

import json
from datetime import datetime
from typing import Dict, List, Optional, Set, Tuple

import networkx as nx


class GraphMemory:
    def __init__(self):
        self.graph = nx.MultiDiGraph()  # Directed graph with multiple edges
        self.entity_index = {}          # entity_name -> node_id mapping
        self.interaction_log = []       # Track source interactions
        self.confidence_threshold = 0.5

    def add_entity(self, entity_name: str, entity_type: str = "generic",
                   attributes: Dict = None, confidence: float = 1.0):
        """Add or update an entity in the graph"""
        entity_id = self._get_or_create_entity_id(entity_name)

        # Update node attributes
        node_attrs = {
            'name': entity_name,
            'type': entity_type,
            'confidence': confidence,
            'created_at': datetime.now().isoformat(),
            'last_updated': datetime.now().isoformat(),
            **(attributes or {})
        }

        if entity_id in self.graph:
            # Merge with existing attributes; preserve the original creation
            # time and don't clobber a specific type with the "generic"
            # default (add_relationship re-adds entities with the default)
            existing_attrs = self.graph.nodes[entity_id]
            if entity_type == "generic" and existing_attrs.get('type', 'generic') != 'generic':
                node_attrs['type'] = existing_attrs['type']
            node_attrs['created_at'] = existing_attrs.get('created_at', node_attrs['created_at'])
            existing_attrs.update(node_attrs)
            existing_attrs['last_updated'] = datetime.now().isoformat()
        else:
            self.graph.add_node(entity_id, **node_attrs)

        return entity_id

    def add_relationship(self, source_entity: str, target_entity: str,
                         relationship_type: str, confidence: float = 1.0,
                         attributes: Dict = None):
        """Add a relationship between two entities"""
        source_id = self.add_entity(source_entity)
        target_id = self.add_entity(target_entity)

        edge_attrs = {
            'type': relationship_type,
            'confidence': confidence,
            'created_at': datetime.now().isoformat(),
            'weight': confidence,  # For weighted graph algorithms
            **(attributes or {})
        }

        self.graph.add_edge(source_id, target_id, **edge_attrs)

    def _get_or_create_entity_id(self, entity_name: str) -> str:
        """Get existing entity ID or create a new one"""
        normalized_name = entity_name.lower().strip()
        if normalized_name not in self.entity_index:
            entity_id = f"entity_{len(self.entity_index)}"
            self.entity_index[normalized_name] = entity_id
        return self.entity_index[normalized_name]

    def find_path(self, source: str, target: str, max_hops: int = 3) -> List[List[str]]:
        """Find paths between two entities"""
        source_id = self.entity_index.get(source.lower())
        target_id = self.entity_index.get(target.lower())

        if not source_id or not target_id:
            return []

        try:
            # Enumerate simple paths up to max_hops edges
            paths = []
            for path in nx.all_simple_paths(self.graph, source_id, target_id, cutoff=max_hops):
                # Convert IDs back to entity names
                entity_path = [self.graph.nodes[node_id]['name'] for node_id in path]
                paths.append(entity_path)
            return paths[:10]  # Limit to top 10 paths
        except nx.NodeNotFound:
            # all_simple_paths raises NodeNotFound for missing endpoints
            return []

    def get_neighbors(self, entity_name: str, relationship_types: List[str] = None,
                      max_distance: int = 1) -> Dict:
        """Get neighboring entities and their relationships"""
        entity_id = self.entity_index.get(entity_name.lower())
        if not entity_id:
            return {}

        neighbors = {}
        if max_distance == 1:
            # Direct neighbors via outgoing and incoming edges, so that
            # reverse traversals (e.g. category -> products) also work
            adjacent = set(self.graph.successors(entity_id)) | set(self.graph.predecessors(entity_id))
            for neighbor_id in adjacent:
                neighbor_name = self.graph.nodes[neighbor_id]['name']

                # Collect edge attributes in both directions,
                # filtered by relationship type if specified
                valid_edges = []
                for edge_data in (self.graph.get_edge_data(entity_id, neighbor_id),
                                  self.graph.get_edge_data(neighbor_id, entity_id)):
                    for edge_key, edge_attrs in (edge_data or {}).items():
                        if not relationship_types or edge_attrs['type'] in relationship_types:
                            valid_edges.append(edge_attrs)

                if valid_edges:
                    neighbors[neighbor_name] = {
                        'entity': self.graph.nodes[neighbor_id],
                        'relationships': valid_edges
                    }
        else:
            # Multi-hop neighbors using BFS (direction-agnostic)
            visited = {entity_id}
            queue = [(entity_id, 0)]

            while queue:
                current_id, distance = queue.pop(0)
                if distance >= max_distance:
                    continue

                for neighbor_id in set(nx.all_neighbors(self.graph, current_id)):
                    if neighbor_id not in visited:
                        neighbor_name = self.graph.nodes[neighbor_id]['name']
                        neighbors[neighbor_name] = {
                            'entity': self.graph.nodes[neighbor_id],
                            'distance': distance + 1
                        }
                        visited.add(neighbor_id)
                        queue.append((neighbor_id, distance + 1))

        return neighbors
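
A minimal usage sketch of the class above (the entity names are hypothetical):

memory = GraphMemory()

# Build a tiny knowledge graph
memory.add_entity("Alice", "Person", {"age": 30})
memory.add_relationship("Alice", "Acme Corp", "works_at", confidence=0.9)
memory.add_relationship("Acme Corp", "Berlin", "located_in")

# Multi-hop reasoning: Alice -> Acme Corp -> Berlin
print(memory.find_path("Alice", "Berlin", max_hops=3))

# Direct neighbors with relationship details
print(memory.get_neighbors("Alice"))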

Implementation Considerations

Entity Extraction and Recognition

Named Entity Recognition (NER)

import spacy


class EntityExtractor:
    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")
        self.custom_entities = {}

    def extract_entities(self, text: str) -> List[Dict]:
        """Extract entities from text using NER"""
        doc = self.nlp(text)
        entities = []

        for ent in doc.ents:
            entity = {
                'text': ent.text,
                'label': ent.label_,
                'start': ent.start_char,
                'end': ent.end_char,
                'confidence': self._calculate_confidence(ent)
            }
            entities.append(entity)

        return entities

    def _calculate_confidence(self, entity) -> float:
        """Calculate confidence score for extracted entity"""
        # Simple heuristic - can be improved with ML models
        base_confidence = 0.8
        length_bonus = min(0.2, len(entity.text) / 50)
        return min(1.0, base_confidence + length_bonus)

    def add_custom_entity_patterns(self, patterns: Dict):
        """Add custom entity recognition patterns"""
        self.custom_entities.update(patterns)
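
A quick usage sketch; this assumes the en_core_web_sm model is installed (python -m spacy download en_core_web_sm), and the exact entities returned depend on the model:

extractor = EntityExtractor()

entities = extractor.extract_entities("Alice joined Acme Corp in Berlin last May.")
for e in entities:
    # Typical spaCy labels here: PERSON, ORG, GPE, DATE
    print(e['text'], e['label'], round(e['confidence'], 2))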

Relationship Extraction

class RelationshipExtractor:
    def __init__(self):
        self.nlp = spacy.load("en_core_web_sm")  # Needed for dependency parsing
        self.dependency_patterns = {
            'owns': ['nsubj', 'dobj'],
            'works_at': ['nsubj', 'prep_at'],
            'lives_in': ['nsubj', 'prep_in'],
            'is_a': ['nsubj', 'attr'],
            'likes': ['nsubj', 'dobj']
        }

    def extract_relationships(self, text: str, entities: List[Dict]) -> List[Dict]:
        """Extract relationships between entities in text"""
        doc = self.nlp(text)
        relationships = []

        # Pattern-based relationship extraction
        for sent in doc.sents:
            relations = self._extract_from_dependencies(sent, entities)
            relationships.extend(relations)

        return relationships

    def _extract_from_dependencies(self, sentence, entities) -> List[Dict]:
        """Extract relationships using dependency parsing"""
        relationships = []

        for token in sentence:
            if token.dep_ in ['nsubj', 'nsubjpass']:
                # Look for object relationships
                for child in token.head.children:
                    if child.dep_ in ['dobj', 'attr', 'prep']:
                        subject = self._find_entity_for_token(token, entities)
                        obj = self._find_entity_for_token(child, entities)

                        if subject and obj:
                            rel_type = self._determine_relationship_type(token.head, child)
                            relationships.append({
                                'subject': subject['text'],
                                'predicate': rel_type,
                                'object': obj['text'],
                                'confidence': 0.7
                            })

        return relationships

    def _find_entity_for_token(self, token, entities) -> Optional[Dict]:
        """Minimal span-overlap matcher: map a token to the extracted
        entity whose character span contains it"""
        for entity in entities:
            if entity['start'] <= token.idx < entity['end']:
                return entity
        return None

    def _determine_relationship_type(self, verb, obj) -> str:
        """Crude heuristic: prefer a configured pattern whose name starts
        with the verb lemma; otherwise fall back to the lemma itself"""
        lemma = verb.lemma_
        for rel_type, deps in self.dependency_patterns.items():
            if rel_type.startswith(lemma) and obj.dep_ in deps:
                return rel_type
        return lemma
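
A usage sketch combining both extractors. The output depends on the spaCy model's tagging and parse, so treat the triple shown in the comment as indicative rather than guaranteed:

entity_extractor = EntityExtractor()
rel_extractor = RelationshipExtractor()

text = "Alice likes Berlin."
entities = entity_extractor.extract_entities(text)
for rel in rel_extractor.extract_relationships(text, entities):
    print(rel['subject'], rel['predicate'], rel['object'])  # e.g. Alice likes Berlin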

Graph Storage and Persistence

File-Based Storage

class FileGraphStorage:
    def __init__(self, filepath: str):
        self.filepath = filepath

    def save_graph(self, graph_memory: GraphMemory):
        """Save graph to JSON file"""
        graph_data = {
            'nodes': dict(graph_memory.graph.nodes(data=True)),
            'edges': [
                {
                    'source': u,
                    'target': v,
                    'key': k,
                    'attributes': d
                }
                for u, v, k, d in graph_memory.graph.edges(keys=True, data=True)
            ],
            'entity_index': graph_memory.entity_index,
            'interaction_log': graph_memory.interaction_log
        }

        with open(self.filepath, 'w') as f:
            json.dump(graph_data, f, indent=2, default=str)

    def load_graph(self) -> GraphMemory:
        """Load graph from JSON file"""
        with open(self.filepath, 'r') as f:
            graph_data = json.load(f)

        graph_memory = GraphMemory()

        # Restore nodes
        for node_id, attributes in graph_data['nodes'].items():
            graph_memory.graph.add_node(node_id, **attributes)

        # Restore edges
        for edge in graph_data['edges']:
            graph_memory.graph.add_edge(
                edge['source'],
                edge['target'],
                key=edge['key'],
                **edge['attributes']
            )

        # Restore indices
        graph_memory.entity_index = graph_data['entity_index']
        graph_memory.interaction_log = graph_data.get('interaction_log', [])

        return graph_memory
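
A save/load round trip might look like this (the file path is arbitrary):

storage = FileGraphStorage("graph_memory.json")

memory = GraphMemory()
memory.add_relationship("Alice", "Acme Corp", "works_at")
storage.save_graph(memory)

restored = storage.load_graph()
print(restored.find_path("Alice", "Acme Corp"))  # [['Alice', 'Acme Corp']]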

Database Storage (Neo4j)

from neo4j import GraphDatabase


class Neo4jGraphStorage:
    def __init__(self, uri: str, username: str, password: str):
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def save_entity(self, entity_name: str, entity_type: str, attributes: Dict):
        """Save entity to Neo4j"""
        with self.driver.session() as session:
            query = """
            MERGE (e:Entity {name: $name})
            SET e.type = $type, e += $attributes
            RETURN e
            """
            session.run(query, name=entity_name, type=entity_type,
                        attributes=attributes)

    def save_relationship(self, source: str, target: str, rel_type: str,
                          attributes: Dict):
        """Save relationship to Neo4j"""
        with self.driver.session() as session:
            query = """
            MATCH (s:Entity {name: $source})
            MATCH (t:Entity {name: $target})
            MERGE (s)-[r:RELATES]->(t)
            SET r.type = $rel_type, r += $attributes
            RETURN r
            """
            session.run(query, source=source, target=target,
                        rel_type=rel_type, attributes=attributes)

    def find_path(self, source: str, target: str, max_hops: int = 3):
        """Find path using Cypher query"""
        with self.driver.session() as session:
            # Cypher cannot parameterize variable-length bounds,
            # hence the string interpolation for max_hops
            query = (
                "MATCH path = shortestPath((s:Entity {name: $source})"
                f"-[:RELATES*1..{max_hops}]-"
                "(t:Entity {name: $target})) "
                "RETURN [node IN nodes(path) | node.name] AS path"
            )
            result = session.run(query, source=source, target=target)
            return [record['path'] for record in result]
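
Usage against a running Neo4j instance; the bolt URI and credentials below are placeholders for your own deployment:

storage = Neo4jGraphStorage("bolt://localhost:7687", "neo4j", "password")

storage.save_entity("Alice", "Person", {"age": 30})
storage.save_entity("Acme Corp", "Organization", {})
storage.save_relationship("Alice", "Acme Corp", "works_at", {"confidence": 0.9})

print(storage.find_path("Alice", "Acme Corp"))
storage.driver.close()  # Release the connection pool when finished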

Graph Analysis and Reasoning

Centrality Analysis

class GraphAnalyzer:
    def __init__(self, graph_memory: GraphMemory):
        self.graph_memory = graph_memory

    def find_central_entities(self, centrality_type: str = 'degree') -> Dict[str, float]:
        """Find most central entities in the graph"""
        if centrality_type == 'degree':
            centrality = nx.degree_centrality(self.graph_memory.graph)
        elif centrality_type == 'betweenness':
            centrality = nx.betweenness_centrality(self.graph_memory.graph)
        elif centrality_type == 'pagerank':
            centrality = nx.pagerank(self.graph_memory.graph)
        else:
            raise ValueError(f"Unknown centrality type: {centrality_type}")

        # Convert node IDs to entity names
        entity_centrality = {}
        for node_id, score in centrality.items():
            entity_name = self.graph_memory.graph.nodes[node_id]['name']
            entity_centrality[entity_name] = score

        return dict(sorted(entity_centrality.items(), key=lambda x: x[1], reverse=True))

    def find_communities(self) -> List[List[str]]:
        """Detect communities in the graph"""
        # Convert to an undirected simple graph for community detection
        # (nx.Graph(...) collapses the parallel edges of the multigraph)
        undirected = nx.Graph(self.graph_memory.graph.to_undirected())
        communities = nx.community.greedy_modularity_communities(undirected)

        # Convert node IDs to entity names
        entity_communities = []
        for community in communities:
            entity_names = [
                self.graph_memory.graph.nodes[node_id]['name']
                for node_id in community
            ]
            entity_communities.append(entity_names)

        return entity_communities

    def suggest_relationships(self, entity: str, top_k: int = 5) -> List[Dict]:
        """Suggest potential new relationships based on graph structure"""
        entity_id = self.graph_memory.entity_index.get(entity.lower())
        if not entity_id:
            return []

        suggestions = []

        # Treat the graph as undirected for common-neighbor link prediction
        neighbors = set(nx.all_neighbors(self.graph_memory.graph, entity_id))
        for neighbor_id in neighbors:
            neighbor_neighbors = set(nx.all_neighbors(self.graph_memory.graph, neighbor_id))

            # Potential connections are neighbors of neighbors not directly connected
            potential = neighbor_neighbors - neighbors - {entity_id}

            for potential_id in potential:
                potential_entity = self.graph_memory.graph.nodes[potential_id]['name']

                # Calculate connection strength based on common neighbors
                common_count = len(neighbors & set(nx.all_neighbors(self.graph_memory.graph, potential_id)))

                suggestions.append({
                    'entity': potential_entity,
                    'strength': common_count,
                    'reason': f"Common connections through {self.graph_memory.graph.nodes[neighbor_id]['name']}"
                })

        # Sort by connection strength and return top suggestions
        suggestions.sort(key=lambda x: x['strength'], reverse=True)
        return suggestions[:top_k]
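
A small sketch of the analyzer on a toy graph (hypothetical entities); with the undirected treatment above, both Bob and Berlin surface as suggestions for Alice through their shared neighbor Acme Corp:

memory = GraphMemory()
memory.add_relationship("Alice", "Acme Corp", "works_at")
memory.add_relationship("Bob", "Acme Corp", "works_at")
memory.add_relationship("Acme Corp", "Berlin", "located_in")

analyzer = GraphAnalyzer(memory)
print(analyzer.find_central_entities('degree'))  # Acme Corp ranks highest
print(analyzer.suggest_relationships("Alice"))   # Bob and Berlin, via Acme Corp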

Performance Characteristics

Pros

  • Rich Relationships: Captures complex multi-entity relationships
  • Inference Capability: Discovers indirect connections and patterns
  • Flexible Structure: Adapts to various domain knowledge structures
  • Query Power: Supports complex graph queries and traversals
  • Knowledge Integration: Merges information from multiple sources

Cons

  • Construction Complexity: Requires sophisticated entity/relationship extraction
  • Storage Overhead: More complex storage than linear patterns
  • Query Complexity: Graph queries can be computationally expensive
  • Maintenance Burden: Keeping the graph accurate and consistent over time
  • Cold Start Problem: Requires substantial data to be effective

Performance Metrics

# Typical performance characteristics
INSERTION_TIME = "O(1)"         # Per entity/relationship
PATH_FINDING = "O(V + E)"       # Breadth-first search
CENTRALITY_ANALYSIS = "O(V³)"   # Depending on algorithm
STORAGE_OVERHEAD = "2-5x"       # Compared to simple key-value storage
QUERY_COMPLEXITY = "O(V^k)"     # Where k is max path length

When to Use

Ideal Scenarios

  • Knowledge-intensive applications requiring reasoning
  • Multi-entity domains with complex relationships
  • Question-answering systems needing inference
  • Recommendation engines based on entity relationships
  • Domain modeling for scientific or technical fields

Avoid When

  • Simple conversational agents without complex entities
  • High-frequency, low-latency systems requiring fast responses
  • Resource-constrained environments with limited processing power
  • Linear workflows without relationship dependencies

Implementation Examples

Conversational Graph Memory

class ConversationalGraphMemory:
    def __init__(self):
        self.graph_memory = GraphMemory()
        self.entity_extractor = EntityExtractor()
        self.relation_extractor = RelationshipExtractor()

    def process_interaction(self, user_input: str, agent_response: str):
        """Process a conversation turn and update graph"""
        interaction_id = len(self.graph_memory.interaction_log)

        # Extract entities from both input and response
        user_entities = self.entity_extractor.extract_entities(user_input)
        agent_entities = self.entity_extractor.extract_entities(agent_response)
        all_entities = user_entities + agent_entities

        # Add entities to graph
        for entity in all_entities:
            self.graph_memory.add_entity(
                entity['text'],
                entity['label'],
                confidence=entity['confidence']
            )

        # Extract relationships; re-extract entities on the combined text so
        # character offsets line up with the combined parse
        combined_text = f"{user_input} {agent_response}"
        combined_entities = self.entity_extractor.extract_entities(combined_text)
        relationships = self.relation_extractor.extract_relationships(combined_text, combined_entities)

        # Add relationships to graph
        for rel in relationships:
            self.graph_memory.add_relationship(
                rel['subject'],
                rel['object'],
                rel['predicate'],
                confidence=rel['confidence']
            )

        # Log interaction
        self.graph_memory.interaction_log.append({
            'id': interaction_id,
            'user_input': user_input,
            'agent_response': agent_response,
            'entities': [e['text'] for e in all_entities],
            'relationships': relationships,
            'timestamp': datetime.now().isoformat()
        })

    def answer_graph_query(self, query: str) -> str:
        """Answer questions using graph knowledge"""
        # Extract entities from query
        query_entities = self.entity_extractor.extract_entities(query)

        if len(query_entities) < 2:
            return "I need at least two entities to find relationships."

        entity1 = query_entities[0]['text']
        entity2 = query_entities[1]['text']

        # Find paths between entities
        paths = self.graph_memory.find_path(entity1, entity2, max_hops=3)

        if not paths:
            return f"I don't know of any connection between {entity1} and {entity2}."

        # Format the response
        shortest_path = paths[0]
        if len(shortest_path) == 2:
            return f"{entity1} is directly connected to {entity2}."
        else:
            path_str = " -> ".join(shortest_path)
            return f"I found this connection: {path_str}"
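
End-to-end usage with hypothetical dialogue; whether a relationship is actually captured depends on the NER results and dependency parse for the given sentences:

chat_memory = ConversationalGraphMemory()
chat_memory.process_interaction(
    "By the way, Alice likes Berlin.",
    "Good to know - I'll remember that Alice likes Berlin."
)

# Answers are grounded in the extracted graph
print(chat_memory.answer_graph_query("What connects Alice and Berlin?"))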

Domain-Specific Graph Memory

class ProductKnowledgeGraph(GraphMemory):
    def __init__(self):
        super().__init__()
        self.product_schema = {
            'Product': ['name', 'category', 'price', 'brand'],
            'Category': ['name', 'parent_category'],
            'Brand': ['name', 'country'],
            'User': ['name', 'preferences']
        }

    def add_product(self, product_info: Dict):
        """Add a product with structured information"""
        product_id = self.add_entity(
            product_info['name'],
            'Product',
            attributes={
                'category': product_info.get('category'),
                'price': product_info.get('price'),
                'brand': product_info.get('brand')
            }
        )

        # Add relationships
        if product_info.get('category'):
            self.add_entity(product_info['category'], 'Category')
            self.add_relationship(product_info['name'], product_info['category'],
                                  'belongs_to_category')

        if product_info.get('brand'):
            self.add_entity(product_info['brand'], 'Brand')
            self.add_relationship(product_info['name'], product_info['brand'],
                                  'manufactured_by')

        return product_id

    def find_similar_products(self, product_name: str,
                              similarity_types: List[str] = None) -> List[str]:
        """Find products similar to the given product"""
        if similarity_types is None:
            similarity_types = ['belongs_to_category', 'manufactured_by']

        similar_products = set()

        # Find products sharing relationships
        for rel_type in similarity_types:
            neighbors = self.get_neighbors(product_name, [rel_type])
            for neighbor_name, neighbor_data in neighbors.items():
                if neighbor_data['entity']['type'] in ['Category', 'Brand']:
                    # Find other products related to this category/brand
                    category_products = self.get_neighbors(
                        neighbor_name, ['belongs_to_category', 'manufactured_by'])
                    for prod_name, prod_data in category_products.items():
                        if (prod_data['entity']['type'] == 'Product'
                                and prod_name != product_name):
                            similar_products.add(prod_name)

        return list(similar_products)

    def recommend_products(self, user_name: str, top_k: int = 5) -> List[Dict]:
        """Recommend products based on user preferences and graph structure"""
        # Get user's purchase/preference history
        user_neighbors = self.get_neighbors(user_name, ['purchased', 'viewed', 'liked'])

        # Collect categories and brands from user history
        preferred_categories = set()
        preferred_brands = set()

        for item_name, item_data in user_neighbors.items():
            item_neighbors = self.get_neighbors(item_name,
                                                ['belongs_to_category', 'manufactured_by'])
            for neighbor_name, neighbor_data in item_neighbors.items():
                if neighbor_data['entity']['type'] == 'Category':
                    preferred_categories.add(neighbor_name)
                elif neighbor_data['entity']['type'] == 'Brand':
                    preferred_brands.add(neighbor_name)

        # Find products matching preferences
        recommendations = []
        for category in preferred_categories:
            category_products = self.get_neighbors(category, ['belongs_to_category'])
            for prod_name, prod_data in category_products.items():
                if prod_name not in user_neighbors:  # Don't recommend already purchased
                    score = self._calculate_recommendation_score(
                        prod_name, preferred_categories, preferred_brands)
                    recommendations.append({
                        'product': prod_name,
                        'score': score,
                        'reason': "Similar to your category preferences"
                    })

        # Sort by score and return top recommendations
        recommendations.sort(key=lambda x: x['score'], reverse=True)
        return recommendations[:top_k]

    def _calculate_recommendation_score(self, product_name: str,
                                        preferred_categories: Set[str],
                                        preferred_brands: Set[str]) -> float:
        """Calculate recommendation score for a product"""
        score = 0.0
        product_neighbors = self.get_neighbors(product_name,
                                               ['belongs_to_category', 'manufactured_by'])

        for neighbor_name, neighbor_data in product_neighbors.items():
            if neighbor_data['entity']['type'] == 'Category' and neighbor_name in preferred_categories:
                score += 0.7
            elif neighbor_data['entity']['type'] == 'Brand' and neighbor_name in preferred_brands:
                score += 0.5

        return score
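
A usage sketch with hypothetical products; recommendations only work once user interaction edges such as purchased exist:

catalog = ProductKnowledgeGraph()
catalog.add_product({'name': 'AeroBook 13', 'category': 'Laptops', 'brand': 'Aero'})
catalog.add_product({'name': 'AeroBook 15', 'category': 'Laptops', 'brand': 'Aero'})

print(catalog.find_similar_products('AeroBook 13'))  # ['AeroBook 15']

# Recommendations need user interaction edges
catalog.add_entity('Dana', 'User')
catalog.add_relationship('Dana', 'AeroBook 13', 'purchased')
print(catalog.recommend_products('Dana'))  # AeroBook 15 (0.7 category + 0.5 brand)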

Temporal Graph Memory

class TemporalGraphMemory(GraphMemory):
    def __init__(self):
        super().__init__()
        self.time_windows = {}  # entity -> list of time windows

    def add_temporal_relationship(self, source: str, target: str, rel_type: str,
                                  start_time: datetime, end_time: datetime = None):
        """Add a relationship with temporal validity"""
        self.add_relationship(source, target, rel_type, attributes={
            'start_time': start_time.isoformat(),
            'end_time': end_time.isoformat() if end_time else None,
            'is_temporal': True
        })

    def get_relationships_at_time(self, entity: str, query_time: datetime) -> List[Dict]:
        """Get relationships valid at a specific time"""
        entity_id = self.entity_index.get(entity.lower())
        if not entity_id:
            return []

        valid_relationships = []
        for neighbor_id in self.graph.neighbors(entity_id):
            edge_data = self.graph.get_edge_data(entity_id, neighbor_id)
            for edge_key, edge_attrs in edge_data.items():
                if edge_attrs.get('is_temporal'):
                    start_time = datetime.fromisoformat(edge_attrs['start_time'])
                    end_time_str = edge_attrs.get('end_time')
                    end_time = datetime.fromisoformat(end_time_str) if end_time_str else datetime.now()

                    if start_time <= query_time <= end_time:
                        neighbor_name = self.graph.nodes[neighbor_id]['name']
                        valid_relationships.append({
                            'target': neighbor_name,
                            'type': edge_attrs['type'],
                            'start_time': start_time,
                            'end_time': end_time
                        })

        return valid_relationships

    def evolve_relationships(self):
        """Update relationship strengths based on temporal patterns"""
        current_time = datetime.now()
        edges_to_remove = []  # Collect first: never mutate the graph while iterating its edge view

        for u, v, key, data in self.graph.edges(keys=True, data=True):
            if 'last_accessed' in data:
                last_access = datetime.fromisoformat(data['last_accessed'])
                days_since_access = (current_time - last_access).days

                # Decay relationship strength over time
                decay_factor = 0.99 ** days_since_access
                data['confidence'] *= decay_factor
                data['weight'] = data['confidence']

                # Mark weak relationships for removal
                if data['confidence'] < self.confidence_threshold:
                    edges_to_remove.append((u, v, key))

        for u, v, key in edges_to_remove:
            self.graph.remove_edge(u, v, key)
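
For example, modeling a job change (hypothetical dates and entities):

timeline = TemporalGraphMemory()
timeline.add_temporal_relationship(
    "Alice", "Acme Corp", "works_at",
    start_time=datetime(2020, 1, 1),
    end_time=datetime(2023, 6, 30)
)
timeline.add_temporal_relationship(
    "Alice", "Globex", "works_at",
    start_time=datetime(2023, 7, 1)  # Open-ended, still valid today
)

# Who did Alice work for in early 2022?
print(timeline.get_relationships_at_time("Alice", datetime(2022, 3, 15)))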

Best Practices

Entity Resolution and Deduplication

class EntityResolver:
    def __init__(self):
        self.similarity_threshold = 0.8

    def resolve_entity(self, new_entity: str, existing_entities: List[str]) -> Optional[str]:
        """Resolve entity mentions to canonical forms"""
        # Simple string similarity approach
        from difflib import SequenceMatcher

        best_match = None
        best_score = 0

        for existing in existing_entities:
            similarity = SequenceMatcher(None, new_entity.lower(), existing.lower()).ratio()
            if similarity > best_score:
                best_score = similarity
                best_match = existing

        if best_score > self.similarity_threshold:
            return best_match
        return None  # No good match found, create new entity

    def merge_entities(self, graph_memory: GraphMemory, entity1: str, entity2: str):
        """Merge two entities in the graph"""
        entity1_id = graph_memory.entity_index.get(entity1.lower())
        entity2_id = graph_memory.entity_index.get(entity2.lower())

        if not entity1_id or not entity2_id:
            return False

        # Merge attributes (entity1 is the canonical form)
        entity1_attrs = graph_memory.graph.nodes[entity1_id]
        entity2_attrs = graph_memory.graph.nodes[entity2_id]

        # Merge non-conflicting attributes
        for key, value in entity2_attrs.items():
            if key not in entity1_attrs:
                entity1_attrs[key] = value

        # Transfer all edges from entity2 to entity1 (materialize the
        # neighbor lists first, since the graph is mutated while transferring)
        # Incoming edges
        for pred_id in list(graph_memory.graph.predecessors(entity2_id)):
            edge_data = graph_memory.graph.get_edge_data(pred_id, entity2_id)
            for key, attrs in edge_data.items():
                graph_memory.graph.add_edge(pred_id, entity1_id, **attrs)

        # Outgoing edges
        for succ_id in list(graph_memory.graph.successors(entity2_id)):
            edge_data = graph_memory.graph.get_edge_data(entity2_id, succ_id)
            for key, attrs in edge_data.items():
                graph_memory.graph.add_edge(entity1_id, succ_id, **attrs)

        # Remove entity2
        graph_memory.graph.remove_node(entity2_id)

        # Update entity index
        del graph_memory.entity_index[entity2.lower()]

        return True
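
A quick sketch of resolution behavior with the default 0.8 similarity threshold (hypothetical names):

resolver = EntityResolver()
known = ["Alice Johnson", "Acme Corp", "Berlin"]

print(resolver.resolve_entity("alice johnson", known))  # 'Alice Johnson'
print(resolver.resolve_entity("Initech", known))        # None -> create new entity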

Graph Maintenance and Cleanup

class GraphMaintenanceManager:
    def __init__(self, graph_memory: GraphMemory):
        self.graph_memory = graph_memory

    def cleanup_low_confidence_data(self, min_confidence: float = 0.3):
        """Remove entities and relationships with low confidence"""
        # Remove low confidence edges
        edges_to_remove = []
        for u, v, key, data in self.graph_memory.graph.edges(keys=True, data=True):
            if data.get('confidence', 1.0) < min_confidence:
                edges_to_remove.append((u, v, key))

        for edge in edges_to_remove:
            self.graph_memory.graph.remove_edge(*edge)

        # Remove isolated nodes with low confidence
        nodes_to_remove = []
        for node_id, data in self.graph_memory.graph.nodes(data=True):
            if (data.get('confidence', 1.0) < min_confidence and
                    self.graph_memory.graph.degree(node_id) == 0):
                nodes_to_remove.append((node_id, data['name']))  # Capture the name before removal

        for node_id, entity_name in nodes_to_remove:
            self.graph_memory.graph.remove_node(node_id)
            # Clean up entity index
            if entity_name.lower() in self.graph_memory.entity_index:
                del self.graph_memory.entity_index[entity_name.lower()]

    def prune_graph_by_centrality(self, keep_top_percent: float = 0.8):
        """Keep only the most central entities"""
        analyzer = GraphAnalyzer(self.graph_memory)
        centrality_scores = analyzer.find_central_entities('degree')

        # Calculate cutoff threshold (clamp the index to stay in range)
        scores = sorted(centrality_scores.values(), reverse=True)
        cutoff_index = min(int(len(scores) * keep_top_percent), len(scores) - 1)
        cutoff_score = scores[cutoff_index]

        # Remove entities below threshold
        entities_to_remove = [
            entity for entity, score in centrality_scores.items()
            if score < cutoff_score
        ]

        for entity in entities_to_remove:
            entity_id = self.graph_memory.entity_index.get(entity.lower())
            if entity_id:
                self.graph_memory.graph.remove_node(entity_id)
                del self.graph_memory.entity_index[entity.lower()]

    def update_relationship_strengths(self):
        """Update relationship strengths based on usage patterns"""
        current_time = datetime.now()

        for u, v, key, data in self.graph_memory.graph.edges(keys=True, data=True):
            # Increase strength for recently accessed relationships
            if 'last_accessed' in data:
                last_access = datetime.fromisoformat(data['last_accessed'])
                days_since = (current_time - last_access).days

                if days_since < 7:  # Accessed within a week
                    boost = 1.1 - (days_since / 7) * 0.1
                    data['confidence'] = min(1.0, data['confidence'] * boost)
                    data['weight'] = data['confidence']
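
For instance, cleaning up a graph containing one low-confidence relationship (hypothetical data):

memory = GraphMemory()
memory.add_relationship("Alice", "Acme Corp", "works_at", confidence=0.9)
memory.add_relationship("Alice", "Initech", "interviewed_at", confidence=0.2)

maintenance = GraphMaintenanceManager(memory)
maintenance.cleanup_low_confidence_data(min_confidence=0.3)  # Drops the 0.2 edge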

Integration with Other Patterns

Graph + Vector Hybrid

class GraphVectorHybrid:
    def __init__(self):
        # ConversationalGraphMemory bundles the graph with its extractors;
        # VectorRetrievalMemory is assumed from the vector retrieval pattern
        self.conversational_graph = ConversationalGraphMemory()
        self.graph_memory = self.conversational_graph.graph_memory
        self.vector_memory = VectorRetrievalMemory()

    def add_interaction(self, user_input: str, agent_response: str):
        # Add to vector memory for semantic search
        self.vector_memory.add_interaction(user_input, agent_response)

        # Add to graph memory for relationship modeling
        self.conversational_graph.process_interaction(user_input, agent_response)

    def retrieve_context(self, query: str, max_hops: int = 2) -> Dict:
        # Get semantically relevant interactions
        vector_results = self.vector_memory.retrieve_relevant(query, top_k=5)

        # Extract entities from vector results
        entities_mentioned = set()
        for result in vector_results:
            extracted = self._extract_entities(result['user_input'] + " " + result['agent_response'])
            entities_mentioned.update([e['text'] for e in extracted])

        # Expand through graph relationships
        graph_expansion = {}
        for entity in entities_mentioned:
            neighbors = self.graph_memory.get_neighbors(entity, max_distance=max_hops)
            graph_expansion[entity] = neighbors

        return {
            'semantic_results': vector_results,
            'graph_expansion': graph_expansion,
            'reasoning_paths': self._find_reasoning_paths(entities_mentioned)
        }

    def _extract_entities(self, text: str) -> List[Dict]:
        """Delegate to the shared entity extractor"""
        return self.conversational_graph.entity_extractor.extract_entities(text)

    def _find_reasoning_paths(self, entities: Set[str]) -> Dict:
        """Find interesting paths between mentioned entities"""
        paths = {}
        entity_list = list(entities)

        for i in range(len(entity_list)):
            for j in range(i + 1, len(entity_list)):
                entity1, entity2 = entity_list[i], entity_list[j]
                found_paths = self.graph_memory.find_path(entity1, entity2, max_hops=3)
                if found_paths:
                    paths[f"{entity1} -> {entity2}"] = found_paths

        return paths

Testing and Validation

Unit Tests

import pytest


def test_graph_memory_basic_operations():
    graph = GraphMemory()

    # Test entity addition
    entity_id = graph.add_entity("John", "Person", {"age": 30})
    assert entity_id in graph.graph.nodes
    assert graph.graph.nodes[entity_id]['name'] == "John"


def test_relationship_creation():
    graph = GraphMemory()
    graph.add_relationship("John", "Apple Inc", "works_at", confidence=0.9)

    # Verify both entities exist
    assert "john" in graph.entity_index
    assert "apple inc" in graph.entity_index

    # Verify relationship exists
    john_id = graph.entity_index["john"]
    apple_id = graph.entity_index["apple inc"]
    assert graph.graph.has_edge(john_id, apple_id)


def test_path_finding():
    graph = GraphMemory()

    # Create a chain: A -> B -> C
    graph.add_relationship("A", "B", "connects_to")
    graph.add_relationship("B", "C", "connects_to")

    paths = graph.find_path("A", "C", max_hops=3)
    assert len(paths) > 0
    assert paths[0] == ["A", "B", "C"]


def test_neighbor_discovery():
    graph = GraphMemory()
    graph.add_relationship("Central", "Node1", "connects")
    graph.add_relationship("Central", "Node2", "connects")
    graph.add_relationship("Node1", "Distant", "connects")

    # Test direct neighbors
    neighbors = graph.get_neighbors("Central", max_distance=1)
    assert "Node1" in neighbors
    assert "Node2" in neighbors
    assert "Distant" not in neighbors

    # Test extended neighbors
    extended = graph.get_neighbors("Central", max_distance=2)
    assert "Distant" in extended

Performance Tests

import random
import time


def test_graph_scalability():
    graph = GraphMemory()

    # Test large graph creation
    start_time = time.time()

    # Create 1000 entities with random relationships
    entities = [f"Entity_{i}" for i in range(1000)]
    for entity in entities:
        graph.add_entity(entity, "TestType")

    # Add 5000 random relationships
    for _ in range(5000):
        source = random.choice(entities)
        target = random.choice(entities)
        if source != target:
            graph.add_relationship(source, target, "test_relation")

    creation_time = time.time() - start_time
    print(f"Graph creation: {creation_time:.2f}s for 1000 entities, 5000 relationships")

    # Test path finding performance
    start_time = time.time()
    for _ in range(100):
        source = random.choice(entities)
        target = random.choice(entities)
        paths = graph.find_path(source, target, max_hops=4)

    pathfinding_time = time.time() - start_time
    print(f"Path finding: {pathfinding_time:.2f}s for 100 queries")
    print(f"Graph size: {graph.graph.number_of_nodes()} nodes, "
          f"{graph.graph.number_of_edges()} edges")

Migration and Scaling

Migration from Other Patterns

def migrate_to_graph_memory(source_memory, entity_extractor, relation_extractor):
    """Migrate from linear memory patterns to graph memory"""
    graph = ConversationalGraphMemory()
    graph.entity_extractor = entity_extractor
    graph.relation_extractor = relation_extractor

    # Get all interactions from source
    if hasattr(source_memory, 'get_all_interactions'):
        interactions = source_memory.get_all_interactions()
    else:
        interactions = source_memory.interactions

    # Process each interaction to build graph
    for interaction in interactions:
        graph.process_interaction(
            interaction['user_input'],
            interaction['agent_response']
        )

    return graph.graph_memory

Distributed Graph Memory

class DistributedGraphMemory:
    def __init__(self, shard_count: int = 4):
        self.shards = [GraphMemory() for _ in range(shard_count)]
        self.shard_count = shard_count
        self.entity_shard_map = {}  # entity -> shard_id

    def add_entity(self, entity_name: str, entity_type: str = "generic"):
        # Consistent hashing for entity placement (note: Python's built-in
        # hash() is salted per process; use a stable hash such as hashlib
        # if the mapping must survive restarts)
        shard_id = hash(entity_name.lower()) % self.shard_count
        self.entity_shard_map[entity_name.lower()] = shard_id
        return self.shards[shard_id].add_entity(entity_name, entity_type)

    def add_relationship(self, source: str, target: str, rel_type: str):
        # Both entities might be on different shards
        source_shard = self.entity_shard_map.get(source.lower())
        target_shard = self.entity_shard_map.get(target.lower())

        if source_shard is None:
            source_shard = self._add_entity_to_shard(source)
        if target_shard is None:
            target_shard = self._add_entity_to_shard(target)

        # If entities are on different shards, need cross-shard relationship handling
        if source_shard != target_shard:
            self._handle_cross_shard_relationship(source, target, rel_type,
                                                  source_shard, target_shard)
        else:
            self.shards[source_shard].add_relationship(source, target, rel_type)

    def find_path_distributed(self, source: str, target: str, max_hops: int = 3):
        """Find paths across distributed shards"""
        source_shard = self.entity_shard_map.get(source.lower())
        target_shard = self.entity_shard_map.get(target.lower())

        if source_shard == target_shard:
            # Same shard - simple path finding
            return self.shards[source_shard].find_path(source, target, max_hops)
        else:
            # Cross-shard path finding - more complex algorithm needed
            return self._find_cross_shard_path(source, target, max_hops)

    def _add_entity_to_shard(self, entity_name: str) -> int:
        """Place a previously unseen entity on its shard"""
        shard_id = hash(entity_name.lower()) % self.shard_count
        self.entity_shard_map[entity_name.lower()] = shard_id
        self.shards[shard_id].add_entity(entity_name)
        return shard_id

    def _handle_cross_shard_relationship(self, source: str, target: str, rel_type: str,
                                         source_shard: int, target_shard: int):
        # Simplest strategy: mirror the remote entity on the source shard so
        # the edge can be stored locally; production systems need real
        # cross-shard edge bookkeeping
        self.shards[source_shard].add_relationship(source, target, rel_type)

    def _find_cross_shard_path(self, source: str, target: str, max_hops: int):
        # Placeholder: a full implementation would run a distributed BFS
        # across shards, following mirrored entities
        return []

Next Steps

  1. Choose appropriate entity and relationship extraction tools
  2. Design your domain-specific graph schema
  3. Implement graph storage solution (file, database, or cloud)
  4. Build query and reasoning capabilities
  5. Add graph maintenance and cleanup procedures
  6. Test performance with realistic data volumes
  7. Consider integration with vector retrieval for hybrid approach

The Graph Memory pattern provides the most sophisticated relationship modeling capabilities but requires careful design and implementation to handle the complexity effectively.