diff --git a/agent.py b/agent.py
index ea07a6e..1797b03 100644
--- a/agent.py
+++ b/agent.py
@@ -6,7 +6,7 @@ from web_search_helper import WebSearchHelper
 
 begin_time = time.time()
 
-# === šŸ”§ Initialize model + tokenizer ===
+# === šŸ”§ Load model + tokenizer ===
 model_id = "meta-llama/Llama-3.2-1B-Instruct"
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 pipe = pipeline(
@@ -15,63 +15,77 @@ pipe = pipeline(
     tokenizer=tokenizer,
     torch_dtype=torch.bfloat16,
     device_map="auto",
-    pad_token_id=128001  # Prevents warning spam
+    pad_token_id=128001
 )
 
-# === 🧠 Core components ===
+# === šŸ”Œ Core modules ===
 memory = Memory()
 searcher = WebSearchHelper()
 
-# === 🧭 System behavior prompt ===
+# === 🧭 System behavior instruction ===
 SYSTEM_PROMPT = """
-You are ą¦•ą§ą¦·ą¦®ą¦¾ (Kshama), Abu's personal AI assistant. You are insightful, methodical, and intentional.
-Capabilities:
-- Recall useful information from persistent memory.
-- Decide when a web search is truly necessary.
-- Summarize web content when requested using clear language.
+You are ą¦•ą§ą¦·ą¦®ą¦¾ (Kshama), Abu's personal AI assistant. You are wise, efficient, and intentional.
 
-Protocols:
-- To store new memory: ##MEM:add("...")
-- To request search: ##SEARCH:yes
-- If no search is needed: ##SEARCH:no
+You can:
+- Recall long-term memory and use it to answer.
+- Summarize long documents clearly.
+- Perform a web search *only if you believe it's necessary*, and state that clearly with ##SEARCH:yes.
 
-Be precise and only initiate web search when memory is insufficient. Don't guess. Use memory and web knowledge actively.
+You also refine web search queries using what you understand of the user's intent.
+Always follow this format:
+- ##MEM:add("...") to add memories
+- ##SEARCH:yes or ##SEARCH:no on its own line to trigger or skip web search
+- After a search: generate a clear answer using memory and the retrieved summaries
 """
 
-# === šŸ“ Summarizer using same model ===
+# === šŸ“˜ Summarization using main model ===
 def summarize_with_llama(text: str) -> str:
-    prompt = f"Summarize the following content briefly:\n\n{text.strip()}\n\nSummary:"
+    prompt = f"Summarize the following:\n\n{text.strip()}\n\nSummary:"
     output = pipe(prompt, max_new_tokens=256)
     return output[0]["generated_text"].replace(prompt, "").strip()
 
-# === šŸ” Check if agent requests web search ===
-def should_search(user_input: str, mem_text: str, kb_text: str) -> bool:
+# === šŸ” Ask if search is needed ===
+def ask_should_search(user_input: str, mem_text: str, kb_text: str) -> bool:
     messages = [
         {"role": "system", "content": SYSTEM_PROMPT},
         {"role": "user", "content": f"User asked: {user_input}"},
         {"role": "user", "content": f"Memory:\n{mem_text or '[None]'}"},
         {"role": "user", "content": f"Web Knowledge:\n{kb_text or '[None]'}"},
-        {"role": "user", "content": "Should you search the web to answer this? Reply with ##SEARCH:yes or ##SEARCH:no only on the first line."}
+        {"role": "user", "content": "Do you need to search the web to answer this? Reply ##SEARCH:yes or ##SEARCH:no on the first line only."}
     ]
     prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-    output = pipe(prompt, max_new_tokens=16, do_sample=False)
-    reply = output[0]["generated_text"].strip().lower()
-    print(output)
-    return reply.splitlines()[0].strip() == "##SEARCH:yes"
+    output = pipe(prompt, max_new_tokens=16, do_sample=False)  # greedy: a yes/no gate should be deterministic
+    reply = output[0]["generated_text"].replace(prompt, "").strip().lower()  # the pipeline echoes the prompt; drop it
+    return bool(reply) and "##search:yes" in reply.splitlines()[0]
 
-# === 🧠 Main agent response handler ===
+# === āœļø Compose better search query ===
+def compose_search_query(user_input: str, mem_text: str) -> str:
+    messages = [
+        {"role": "system", "content": SYSTEM_PROMPT},
+        {"role": "user", "content": f"User asked: {user_input}"},
+        {"role": "user", "content": f"Relevant memory:\n{mem_text or '[None]'}"},
+        {"role": "user", "content": "Rewrite the request as a concise web search query. Output only the query string, nothing else."}
+    ]
+    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+    output = pipe(prompt, max_new_tokens=32)
+    query = output[0]["generated_text"].replace(prompt, "").strip()  # again, drop the echoed prompt
+    return query.splitlines()[0] if query else user_input  # fall back to the raw input
+
+# === 🧠 Main reasoning function ===
 def generate_response(user_input: str):
-    # Step 1: Retrieve memory and knowledgebase
+    # Step 1: Recall memory and web KB
     mem_hits = memory.query(user_input, top_k=3)
-    mem_text = "\n".join([f"- {m}" for m in mem_hits])
+    mem_text = "\n".join([f"- {x}" for x in mem_hits])
 
-    _, kb_hits = searcher.query_kb(user_input, top_k=3)
+    _, kb_hits = searcher.query_kb(user_input)
     kb_text = "\n".join([f"- {k['summary']}" for k in kb_hits])
 
-    # Step 2: Ask if search is needed
-    if should_search(user_input, mem_text, kb_text):
+    # Step 2: Ask model if search is truly required
+    if ask_should_search(user_input, mem_text, kb_text):
         print("[🌐 Search Triggered]")
-        urls = searcher.search_duckduckgo(user_input)
+        search_query = compose_search_query(user_input, mem_text)
+        print(f"[šŸ”Ž Composed Query] {search_query}")
+        urls = searcher.search_duckduckgo(search_query)
         summaries = searcher.crawl_and_summarize(urls, llm_function=summarize_with_llama)
         searcher.add_to_kb(summaries)
         _, kb_hits = searcher.query_kb(user_input)
@@ -79,7 +93,7 @@ def generate_response(user_input: str):
     else:
         print("[šŸ”’ Search Skipped]")
 
-    # Step 3: Generate final answer
+    # Step 3: Final answer generation
     messages = [
         {"role": "system", "content": SYSTEM_PROMPT},
         {"role": "user", "content": user_input},
@@ -87,31 +101,30 @@ def generate_response(user_input: str):
         {"role": "user", "content": f"Web Knowledge:\n{kb_text or '[None]'}"}
     ]
     full_prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
-
     start = time.time()
     output = pipe(full_prompt, max_new_tokens=512)
     elapsed = time.time() - start
     response = output[0]["generated_text"].replace(full_prompt, "").strip()
 
     # Step 4: Store memory if requested
     if "##MEM:add(" in response:
         try:
             content = response.split("##MEM:add(")[1].split(")")[0].strip('"\'')
             memory.add(content)
             print("[āœ… Memory Added]")
         except Exception as e:
-            print(f"[āš ļø Could not parse memory]: {e}")
+            print(f"[āš ļø Failed to add memory]: {e}")
 
     return response, elapsed
 
-# === šŸ‘‚ Main loop ===
+# === šŸ’¬ REPL Loop ===
 if __name__ == "__main__":
     print(f"šŸš€ Kshama ready in {time.time() - begin_time:.2f}s")
     print("šŸ‘‹ Hello, Abu. Type 'exit' to quit.")
     while True:
         user_input = input("\nšŸ§‘ You: ")
         if user_input.strip().lower() in ["exit", "quit"]:
-            print("šŸ‘‹ Farewell.")
+            print("šŸ‘‹ Goodbye.")
             break
         response, delay = generate_response(user_input)
         print(f"\nšŸ¤– ą¦•ą§ą¦·ą¦®ą¦¾ [{delay:.2f}s]: {response}")
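
Note for reviewers: web_search_helper.py is not part of this diff, so below is a minimal sketch of the interface agent.py assumes, inferred purely from the call sites above. The method names and the 'summary' key come from the diff itself; the type hints, the meaning of query_kb's first return value, and the stub bodies are guesses, not the actual implementation.

    # Hypothetical interface sketch -- inferred from agent.py's call sites, not from web_search_helper.py
    from typing import Callable

    class WebSearchHelper:
        def query_kb(self, text: str, top_k: int = 5) -> tuple[list, list[dict]]:
            """Return (scores, hits); agent.py reads hit['summary'] from each hit."""
            ...

        def search_duckduckgo(self, query: str) -> list[str]:
            """Return result URLs for the query."""
            ...

        def crawl_and_summarize(self, urls: list[str], llm_function: Callable[[str], str]) -> list[dict]:
            """Fetch each URL and summarize its text with llm_function."""
            ...

        def add_to_kb(self, summaries: list[dict]) -> None:
            """Embed and persist new summaries into the knowledge base."""
            ...
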
Type 'exit' to quit.") while True: user_input = input("\nšŸ§‘ You: ") if user_input.strip().lower() in ["exit", "quit"]: - print("šŸ‘‹ Farewell.") + print("šŸ‘‹ Goodbye.") break response, delay = generate_response(user_input) print(f"\nšŸ¤– ą¦•ą§ą¦·ą¦®ą¦¾ [{delay:.2f}s]: {response}") diff --git a/memory.py b/memory.py index f58a746..49af19a 100644 --- a/memory.py +++ b/memory.py @@ -29,9 +29,11 @@ class Memory: self._save() def query(self, text, top_k=5): + if self.index.ntotal == 0: + return [] vec = embedder.encode([text]) D, I = self.index.search(vec, top_k) - return [self.metadata[i]["text"] for i in I[0]] + return [self.metadata[i]["text"] for i in I[0] if 0 <= i < len(self.metadata)] def _save(self): faiss.write_index(self.index, self.index_path)