🤖 ReAct Agent + RAGLens (Pretrained SAE)¶
- LLM: Llama-3.1-8B-Instruct (4-bit)
- SAE: EleutherAI/sae-llama-3.1-8b-64x @ layers.23.mlp
- Tools: calculator, python, search, convert, finish
1. Setup¶
In [1]:
!pip install -q transformers accelerate torch bitsandbytes eai-sparsify scikit-learn
In [2]:
# Authenticate with the Hugging Face Hub using a Kaggle-stored secret.
# Required because the Llama-3.1 weights are gated behind a license.
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")  # secret named HF_TOKEN in Kaggle settings
login(token=hf_token)
print("✅ Logged in!")
✅ Logged in!
In [3]:
# Core imports shared by the monitor, tools, and agent below.
import re, math, gc
import numpy as np
from typing import Dict, List, Callable, Tuple
from dataclasses import dataclass
import torch

# Report which accelerator (if any) this session runs on.
print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
GPU: Tesla P100-PCIE-16GB
2. Load LLM + SAE¶
In [4]:
# Load Llama-3.1-8B-Instruct in 4-bit NF4 quantization so the model
# fits on a 16 GB GPU with headroom for activations.
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,  # matmuls run in fp16
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
    bnb_4bit_quant_type="nf4"              # normal-float 4-bit
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model.eval()
# Llama ships without a pad token; reuse EOS so generation can pad.
tokenizer.pad_token = tokenizer.eos_token
gc.collect(); torch.cuda.empty_cache()
print(f"✅ LLM loaded! GPU: {torch.cuda.memory_allocated()/1e9:.2f} GB")
2026-02-08 13:56:27.734895: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered WARNING: All log messages before absl::InitializeLog() is called are written to STDERR E0000 00:00:1770558987.755324 624 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered E0000 00:00:1770558987.760421 624 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered W0000 00:00:1770558987.773795 624 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once. W0000 00:00:1770558987.773808 624 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once. W0000 00:00:1770558987.773811 624 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once. W0000 00:00:1770558987.773813 624 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
Loading checkpoint shards: 0%| | 0/4 [00:00<?, ?it/s]
✅ LLM loaded! GPU: 5.71 GB
In [5]:
# Load the pretrained EleutherAI sparse autoencoder trained on the
# layer-23 MLP output. Kept on CPU; activations are moved over at
# encode time by the monitor.
from sparsify import Sae

SAE_NAME = "EleutherAI/sae-llama-3.1-8b-64x"
LAYER_IDX = 23
HOOKPOINT = f"layers.{LAYER_IDX}.mlp"
sae = Sae.load_from_hub(SAE_NAME, hookpoint=HOOKPOINT, device="cpu")
print(f"✅ SAE loaded! Latents: {sae.num_latents}")
# Test SAE output structure — the demo run shows encode() returns an
# EncoderOutput with top_acts/top_indices (top-k sparse encoding).
test_input = torch.randn(1, sae.d_in)
test_out = sae.encode(test_input)
print(f"SAE output type: {type(test_out)}")
print(f"SAE output attrs: {[a for a in dir(test_out) if not a.startswith('_')]}")
Fetching 2 files: 0%| | 0/2 [00:00<?, ?it/s]
✅ SAE loaded! Latents: 262144 SAE output type: <class 'sparsify.fused_encoder.EncoderOutput'> SAE output attrs: ['count', 'index', 'pre_acts', 'top_acts', 'top_indices']
3. RAGLens Monitor¶
In [13]:
class RAGLensMonitor:
    """Monitor agent reasoning with SAE latent features.

    Each call to ``detect`` runs the text through the LLM, captures the
    MLP output at ``layer_idx`` via a forward hook, encodes it with the
    sparse autoencoder, and derives a crude hallucination score from
    the fraction of active latents.
    """

    def __init__(self, model, tokenizer, sae, layer_idx=23):
        self.model, self.tokenizer, self.sae = model, tokenizer, sae
        self.layer_idx = layer_idx
        self.detection_log = []  # one result dict per detect() call
        self._hidden = None      # MLP activation stashed by the hook

    def _hook(self, m, i, o):
        # Forward hook: stash the MLP output. Assumes `o` is a tensor,
        # not a tuple (true for Llama MLP modules per the demo run).
        self._hidden = o.detach()

    def extract_features(self, text: str) -> np.ndarray:
        """Return per-latent max activations for `text` (zeros on any error)."""
        hook = self.model.model.layers[self.layer_idx].mlp.register_forward_hook(self._hook)
        try:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.inference_mode():
                self.model(**inputs)
            # Drop the batch dim if present: (B, T, D) -> (T, D).
            h = self._hidden[0] if self._hidden.dim() == 3 else self._hidden
            enc_out = self.sae.encode(h.cpu().float())
            # sparsify's EncoderOutput exposes top-k acts/indices; other
            # attribute names are probed for forward/backward compat.
            if hasattr(enc_out, 'top_acts'):
                # Scatter the top-k activations into a dense (tokens, latents) matrix.
                latent = torch.zeros(h.shape[0], self.sae.num_latents)
                latent.scatter_(1, enc_out.top_indices, enc_out.top_acts.detach())
            elif hasattr(enc_out, 'latent_acts'):
                latent = enc_out.latent_acts.detach()
            elif hasattr(enc_out, 'latents'):
                latent = enc_out.latents.detach()
            elif isinstance(enc_out, torch.Tensor):
                latent = enc_out.detach()
            else:
                # Last resort: take the first tensor-valued attribute.
                for attr in dir(enc_out):
                    val = getattr(enc_out, attr)
                    if isinstance(val, torch.Tensor) and val.dim() >= 1:
                        latent = val.detach()
                        break
                else:
                    return np.zeros(self.sae.num_latents)
            # Max-pool over tokens -> one activation per latent.
            # (All branches above already detached `latent`.)
            return latent.max(0)[0].numpy()
        except Exception as e:
            print(f"Feature error: {e}")
            return np.zeros(self.sae.num_latents)
        finally:
            hook.remove()
            self._hidden = None

    def detect(self, text: str, step: str = "") -> Dict:
        """Score `text`, append to the log, and return the result dict."""
        f = self.extract_features(text)
        active = (f > 0.1).sum()
        # Heuristic score: scaled fraction of active latents, capped at 1.0.
        prob = min(active / len(f) * 2, 1.0)
        # Cast to plain Python types — without bool()/float() numpy
        # bool_/float64 leak into the log (breaks e.g. JSON dumps).
        result = {"step": step, "hallucination": bool(prob > 0.3), "confidence": float(prob)}
        self.detection_log.append(result)
        return result

    def get_summary(self) -> Dict:
        """Aggregate counts over every detection logged so far."""
        if not self.detection_log:
            return {"total": 0, "hallucinations": 0}
        h = sum(1 for d in self.detection_log if d["hallucination"])
        return {"total": len(self.detection_log), "hallucinations": h}

    def clear(self):
        """Reset the detection log (called at the start of each agent run)."""
        self.detection_log = []

print("✅ RAGLensMonitor")
✅ RAGLensMonitor
4. Tools¶
In [14]:
class ToolRegistry:
    """Simple name -> callable registry with one-line descriptions."""

    def __init__(self):
        self.tools = {}
        self.descs = {}

    def register(self, name, func, desc):
        """Add (or replace) a tool under `name`."""
        self.tools[name] = func
        self.descs[name] = desc

    def execute(self, name, args):
        """Invoke a tool; every outcome comes back as a string."""
        func = self.tools.get(name)
        if func is None:
            return f"Unknown tool: {name}"
        try:
            return str(func(args))
        except Exception as e:
            return f"Error: {e}"

    def desc_str(self):
        """One '- name: description' line per registered tool."""
        lines = [f"- {n}: {d}" for n, d in self.descs.items()]
        return "\n".join(lines)
def calculator(expr):
    """Safely evaluate a math expression; result rounded to 6 decimals.

    Only whitelisted names are visible to the expression — builtins are
    emptied so model-generated input cannot call arbitrary functions.
    The name set is a backward-compatible superset of the original
    (adds tan/log/log10/exp/floor/ceil/e).
    """
    expr = expr.strip().strip('"\'')
    safe = {"sqrt": math.sqrt, "sin": math.sin, "cos": math.cos, "tan": math.tan,
            "log": math.log, "log10": math.log10, "exp": math.exp,
            "floor": math.floor, "ceil": math.ceil,
            "pi": math.pi, "e": math.e, "pow": pow, "abs": abs}
    # NOTE(security): eval on tool input is tolerable only because
    # __builtins__ is an empty dict and names are whitelisted above.
    return round(eval(expr, {"__builtins__": {}}, safe), 6)
def python_exec(code):
    """Run a Python snippet in a restricted namespace.

    Returns captured stdout if the snippet printed anything, otherwise
    the str() of the snippet evaluated as an expression.

    Fixes vs. the original: the bare ``except:`` is narrowed to
    ``Exception``, and the eval fallback now shares the restricted
    environment — previously ``eval(code, {"math": math})`` silently
    received FULL builtins (no ``__builtins__`` key), defeating the
    sandbox the exec path set up.
    """
    code = code.strip().strip('"\'')
    import io, contextlib
    # One shared sandbox for both exec and eval.
    env = {"math": math,
           "__builtins__": {"print": print, "range": range, "len": len, "sum": sum}}
    buf = io.StringIO()
    try:
        with contextlib.redirect_stdout(buf):
            exec(code, env)
        out = buf.getvalue().strip()
        # No printed output: re-evaluate as an expression for its value.
        return out or str(eval(code, env))
    except Exception:
        # Statement path failed; retry as a pure expression (errors here
        # propagate to ToolRegistry.execute, which reports them).
        return str(eval(code, env))
def search(q):
    """Look up a canned fact by keyword match against a tiny KB."""
    query = q.lower().strip().strip('"\'')
    kb = {"light": "Speed of light ≈ 3×10⁸ m/s", "gravity": "g ≈ 9.8 m/s²",
          "sqrt": "√256 = 16, √144 = 12, √100 = 10", "square": "√256 = 16"}
    # First keyword (dict insertion order) contained in the query wins.
    hit = next((fact for key, fact in kb.items() if key in query), None)
    return hit if hit is not None else f"No info for '{query}'"
def convert(s):
    """Convert '<value> <unit1> to <unit2>' for a few supported pairs.

    Backward compatible: the original pairs (celsius->fahrenheit,
    km->miles) are untouched; the inverse directions are new.
    Returns a usage hint on parse failure, 'Unsupported' otherwise.
    """
    s = s.strip().strip('"\'')
    m = re.match(r"([\d.]+)\s*(\w+)\s+to\s+(\w+)", s.lower())
    if not m:
        return "Format: 'value unit1 to unit2'"
    v, fr, to = float(m[1]), m[2], m[3]
    if fr == "celsius" and to == "fahrenheit":
        return f"{v}°C = {v*9/5+32:.2f}°F"
    if fr == "fahrenheit" and to == "celsius":
        return f"{v}°F = {(v-32)*5/9:.2f}°C"
    if fr == "km" and to == "miles":
        return f"{v}km = {v/1.609:.2f}mi"
    if fr == "miles" and to == "km":
        return f"{v}mi = {v*1.609:.2f}km"
    return "Unsupported"
def finish(ans):
    """Terminal tool: wrap the final answer so the agent loop can stop."""
    return "FINAL: " + ans.strip()

print("✅ Tools")
✅ Tools
5. ReAct Agent¶
In [15]:
@dataclass
class AgentStep:
    """One ReAct iteration: thought -> action -> observation, plus RAGLens flags."""
    step: int
    thought: str
    action: str
    input: str
    observation: str
    hall: bool = False   # RAGLens hallucination flag for this step
    conf: float = 0.0    # RAGLens confidence score


@dataclass
class AgentResult:
    """Outcome of a full agent run over one question."""
    question: str
    answer: str
    steps: List[AgentStep]
    success: bool
In [16]:
class ReActAgent:
    """ReAct loop: think -> pick tool -> observe, with optional RAGLens
    monitoring of every step. Stops early when the `finish` tool fires.
    """

    def __init__(self, model, tokenizer, monitor=None, max_steps=5):
        self.model, self.tokenizer, self.monitor = model, tokenizer, monitor
        self.max_steps = max_steps
        self.tools = ToolRegistry()
        # Register the five demo tools with one-line usage hints that
        # are shown to the model inside the prompt.
        self.tools.register("calculator", calculator, "Evaluate math: calculator(sqrt(256))")
        self.tools.register("python", python_exec, "Run Python: python(math.sqrt(256))")
        self.tools.register("search", search, "Search facts: search(speed of light)")
        self.tools.register("convert", convert, "Convert units: convert(100 celsius to fahrenheit)")
        self.tools.register("finish", finish, "Final answer: finish(16)")

    def _prompt(self, q, history):
        # Build the ReAct prompt. Note: prior steps are appended AFTER
        # the question, so the model sees its own trajectory last.
        p = f"""Answer questions using tools. Stop immediately when you have the answer.
Tools:
{self.tools.desc_str()}
IMPORTANT: Output ONLY in this format, nothing else:
Thought: [your reasoning]
Action: [tool name]
Action Input: [single line input]
Question: {q}
"""
        for s in history:
            p += f"\nThought: {s.thought}\nAction: {s.action}\nAction Input: {s.input}\nObservation: {s.observation}\n"
        return p

    def _parse(self, r):
        # Extract (thought, action, input) from the raw completion with
        # tolerant regexes; each field falls back to "" when absent.
        thought = action = inp = ""
        if m := re.search(r"Thought:\s*(.+?)(?=\nAction:|$)", r, re.DOTALL):
            # Keep only the first line of the thought.
            thought = m[1].strip().split("\n")[0]
        if m := re.search(r"Action:\s*(\w+)", r):
            action = m[1].strip()
        if m := re.search(r"Action Input:\s*(.+)", r):
            inp = m[1].strip().split("\n")[0]
            # Models sometimes tack "Answer: ..." onto the same line; drop it.
            if "Answer:" in inp: inp = inp.split("Answer:")[0].strip()
        return thought, action, inp

    def _generate(self, prompt):
        # Wrap the prompt in the chat template and sample a short,
        # near-greedy completion (temperature 0.1).
        msgs = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(msgs, return_tensors="pt", add_generation_prompt=True)
        input_ids = input_ids.to(self.model.device)
        input_len = input_ids.shape[1]
        with torch.inference_mode():
            out = self.model.generate(input_ids, max_new_tokens=150, temperature=0.1, do_sample=True,
                pad_token_id=self.tokenizer.eos_token_id)
        # Decode only the newly generated tokens, not the prompt.
        return self.tokenizer.decode(out[0][input_len:], skip_special_tokens=True)

    def run(self, question):
        """Run the ReAct loop on `question`; returns an AgentResult."""
        print("="*50)
        print(f"❓ {question}")
        print("="*50)
        # Fresh detection log per run so the summary covers this question only.
        if self.monitor: self.monitor.clear()
        history = []
        answer = None
        for step in range(1, self.max_steps + 1):
            print(f"\n--- Step {step}/{self.max_steps} ---")
            resp = self._generate(self._prompt(question, history))
            thought, action, inp = self._parse(resp)
            print(f"💭 {thought[:50]}" + ("..." if len(thought)>50 else ""))
            print(f"🔧 {action}: {inp}")
            obs = self.tools.execute(action, inp) if action else "No action"
            print(f"👁️ {obs}")
            hall, conf = False, 0.0
            if self.monitor:
                # Score the concatenated thought + observation text.
                det = self.monitor.detect(f"{thought} {obs}", f"Step {step}")
                hall, conf = det["hallucination"], det["confidence"]
                print(f"{'⚠️' if hall else '✅'} RAGLens: {conf:.1%}")
            history.append(AgentStep(step, thought, action or "none", inp, obs, hall, conf))
            # Terminate when the finish tool ran (its output is "FINAL: ...").
            if action == "finish" or "FINAL:" in obs:
                answer = obs.replace("FINAL:", "").strip()
                print(f"\n✅ DONE: {answer}")
                break
        if not answer:
            print("\n⚠️ Max steps")
            answer = "Unknown"
        if self.monitor:
            s = self.monitor.get_summary()
            print(f"📊 RAGLens: {s['hallucinations']}/{s['total']}")
        # NOTE(review): a genuine answer of "Unknown" would be reported
        # as failure here — acceptable for this demo.
        return AgentResult(question, answer, history, answer != "Unknown")

print("✅ ReActAgent")
✅ ReActAgent
6. Run Tests 🚀¶
In [17]:
# Wire the monitor and agent together, then run the first smoke test.
monitor = RAGLensMonitor(model, tokenizer, sae, layer_idx=LAYER_IDX)
agent = ReActAgent(model, tokenizer, monitor, max_steps=5)
print("\n🎯 Test 1: Math")
r1 = agent.run("What is sqrt(256)?")
🎯 Test 1: Math ================================================== ❓ What is sqrt(256)? ================================================== --- Step 1/5 --- 💭 The problem is asking for the square root of 256. 🔧 calculator: sqrt(256) 👁️ 16.0 ✅ RAGLens: 0.3% --- Step 2/5 --- 💭 The problem is asking for the square root of 256. 🔧 python: math.sqrt(256) 👁️ 16.0 ✅ RAGLens: 0.3% --- Step 3/5 --- 💭 The problem is asking for the square root of 256. 🔧 finish: 16 👁️ FINAL: 16 ✅ RAGLens: 0.3% ✅ DONE: 16 📊 RAGLens: 0/3
In [18]:
# Second smoke test: unit conversion via the convert tool.
print("\n🎯 Test 2: Conversion")
r2 = agent.run("Convert 100 celsius to fahrenheit")
🎯 Test 2: Conversion ================================================== ❓ Convert 100 celsius to fahrenheit ================================================== --- Step 1/5 --- 💭 We need to convert a temperature from Celsius to F... 🔧 convert: 100 celsius to fahrenheit 👁️ 100.0°C = 212.00°F ✅ RAGLens: 0.4% --- Step 2/5 --- 💭 We need to convert a temperature from Celsius to F... 🔧 convert: 100 celsius to fahrenheit 👁️ 100.0°C = 212.00°F ✅ RAGLens: 0.4% --- Step 3/5 --- 💭 We need to convert a temperature from Celsius to F... 🔧 convert: 100 celsius to fahrenheit 👁️ 100.0°C = 212.00°F ✅ RAGLens: 0.4% --- Step 4/5 --- 💭 We need to convert a temperature from Celsius to F... 🔧 convert: 100 celsius to fahrenheit 👁️ 100.0°C = 212.00°F ✅ RAGLens: 0.4% --- Step 5/5 --- 💭 We need to convert a temperature from Celsius to F... 🔧 convert: 100 celsius to fahrenheit 👁️ 100.0°C = 212.00°F ✅ RAGLens: 0.4% ⚠️ Max steps 📊 RAGLens: 0/5
In [19]:
# Third smoke test: factual lookup via the search tool.
print("\n🎯 Test 3: Knowledge")
r3 = agent.run("What is the speed of light?")
🎯 Test 3: Knowledge ================================================== ❓ What is the speed of light? ================================================== --- Step 1/5 --- 💭 The speed of light is a fundamental constant in ph... 🔧 search: speed of light 👁️ Speed of light ≈ 3×10⁸ m/s ✅ RAGLens: 0.4% --- Step 2/5 --- 💭 The speed of light is a fundamental constant in ph... 🔧 search: speed of light in m/s 👁️ Speed of light ≈ 3×10⁸ m/s ✅ RAGLens: 0.4% --- Step 3/5 --- 💭 The speed of light is a fundamental constant in ph... 🔧 search: speed of light in m/s 👁️ Speed of light ≈ 3×10⁸ m/s ✅ RAGLens: 0.4% --- Step 4/5 --- 💭 The speed of light is a fundamental constant in ph... 🔧 search: speed of light in m/s 👁️ Speed of light ≈ 3×10⁸ m/s ✅ RAGLens: 0.4% --- Step 5/5 --- 💭 The speed of light is a fundamental constant in ph... 🔧 finish: 3×10⁸ 👁️ FINAL: 3×10⁸ ✅ RAGLens: 0.4% ✅ DONE: 3×10⁸ 📊 RAGLens: 0/5