40 lines
1.3 KiB
Python
40 lines
1.3 KiB
Python
import faiss
|
|
import os
|
|
import pickle
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
# Embedder: small model is fast & good enough
|
|
# Module-level embedder shared by every Memory instance; it is constructed when
# this module is imported. The FAISS index dimension used by Memory must match
# this model's output dimension, so swapping the model means updating that
# dimension as well.
embedder = SentenceTransformer("all-MiniLM-L6-v2") # Replaceable
|
|
|
|
class Memory:
    """Persistent text memory backed by a FAISS index plus a pickled metadata list.

    Each stored text is embedded with the module-level ``embedder`` and added
    to a flat L2 FAISS index. A parallel list, ``self.metadata``, keeps one
    ``{"text": ..., "tags": ...}`` dict per stored vector, in insertion order,
    so the position of a vector in the index is also its position in the list.
    Both structures are written to disk after every ``add``.
    """

    # Dimension of a freshly created index. It must equal the output dimension
    # of the module-level embedder; update it if the embedder model is changed.
    _EMBEDDING_DIM = 384

    def __init__(self, index_path="kshama.index", metadata_path="memory_meta.pkl"):
        """Open the memory stored at the given paths, or start an empty one.

        Args:
            index_path: File the FAISS index is read from / written to.
            metadata_path: File the pickled metadata list is read from /
                written to.

        Raises:
            FileNotFoundError: If ``index_path`` exists but ``metadata_path``
                does not (see ``_load``).
        """
        self.index_path = index_path
        self.metadata_path = metadata_path
        self.index = None
        self.metadata = []
        self._load()

    def _load(self):
        """Populate ``self.index`` and ``self.metadata`` from disk if present.

        When ``index_path`` exists, the index and the metadata list are both
        read from disk; a missing metadata file then raises
        ``FileNotFoundError``. Otherwise an empty flat L2 index is created and
        the metadata list is left empty.
        """
        if os.path.exists(self.index_path):
            self.index = faiss.read_index(self.index_path)
            # SECURITY: pickle.load can execute arbitrary code. Only point
            # metadata_path at files this application wrote itself.
            with open(self.metadata_path, "rb") as f:
                self.metadata = pickle.load(f)
        else:
            self.index = faiss.IndexFlatL2(self._EMBEDDING_DIM)

    def add(self, text, tags=None):
        """Embed ``text``, store it with its tags, and persist to disk.

        Args:
            text: The text to remember.
            tags: Optional tags stored alongside the text; a falsy value is
                stored as an empty list.
        """
        vec = embedder.encode([text])
        self.index.add(vec)
        self.metadata.append({"text": text, "tags": tags or []})
        self._save()

    def query(self, text, top_k=5):
        """Return the stored texts closest to ``text``, nearest first.

        Args:
            text: The query text.
            top_k: Maximum number of texts to return.

        Returns:
            A list of at most ``top_k`` stored texts. The list is empty when
            nothing has been stored yet or when ``top_k`` is not positive.
        """
        # Nothing to search: skip the embedding work entirely.
        if top_k <= 0 or not self.metadata or self.index.ntotal == 0:
            return []

        vec = embedder.encode([text])
        # Never ask for more neighbours than exist; FAISS pads the result
        # with -1 labels when it cannot find enough of them.
        k = min(top_k, self.index.ntotal)
        _distances, ids = self.index.search(vec, k)
        return self._texts_for(ids[0])

    def _texts_for(self, ids):
        """Map FAISS result labels to stored texts, dropping invalid labels.

        Labels outside ``[0, len(self.metadata))`` are skipped. This covers
        the ``-1`` padding FAISS uses for missing neighbours, which would
        otherwise index the metadata list from the end and return a wrong
        entry.
        """
        count = len(self.metadata)
        return [self.metadata[i]["text"] for i in ids if 0 <= i < count]

    def _save(self):
        """Write the FAISS index and the metadata list to their files."""
        faiss.write_index(self.index, self.index_path)
        with open(self.metadata_path, "wb") as f:
            pickle.dump(self.metadata, f)