🤖 ReAct Agent + RAGLens (Pretrained SAE)¶

  • LLM: Llama-3.1-8B-Instruct (4-bit)
  • SAE: EleutherAI/sae-llama-3.1-8b-64x @ layers.23.mlp
  • Tools: calculator, python, search, convert, finish

1. Setup¶

In [1]:
!pip install -q transformers accelerate torch bitsandbytes eai-sparsify scikit-learn
In [2]:
from kaggle_secrets import UserSecretsClient
from huggingface_hub import login

user_secrets = UserSecretsClient()
hf_token = user_secrets.get_secret("HF_TOKEN")
login(token=hf_token)
print("✅ Logged in!")
✅ Logged in!
In [3]:
import re, math, gc
import numpy as np
from typing import Dict, List, Callable, Tuple
from dataclasses import dataclass
import torch

print(f"GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'CPU'}")
GPU: Tesla P100-PCIE-16GB

2. Load LLM + SAE¶

In [4]:
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

MODEL_NAME = "meta-llama/Llama-3.1-8B-Instruct"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4"
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    quantization_config=bnb_config,
    device_map="auto",
    trust_remote_code=True
)
model.eval()
tokenizer.pad_token = tokenizer.eos_token
gc.collect(); torch.cuda.empty_cache()
print(f"✅ LLM loaded! GPU: {torch.cuda.memory_allocated()/1e9:.2f} GB")
✅ LLM loaded! GPU: 5.71 GB
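The ~5.7 GB footprint is roughly what 4-bit NF4 quantization predicts. A back-of-envelope check (the 8B parameter count is an assumption taken from the model name, not measured):

```python
# Rough VRAM estimate for a 4-bit quantized ~8B-parameter model.
params = 8.0e9            # ~8B parameters (assumed from the model name)
bits_per_param = 4        # NF4 stores each weight in 4 bits
weights_gb = params * bits_per_param / 8 / 1e9   # bytes -> GB
print(f"~{weights_gb:.1f} GB for weights alone")  # ~4.0 GB
```

Double quantization adds small per-block scale factors, and embeddings/norms may stay in higher precision, so the real footprint lands somewhat above this estimate, consistent with the 5.71 GB reported.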
In [5]:
from sparsify import Sae

SAE_NAME = "EleutherAI/sae-llama-3.1-8b-64x"
LAYER_IDX = 23
HOOKPOINT = f"layers.{LAYER_IDX}.mlp"

sae = Sae.load_from_hub(SAE_NAME, hookpoint=HOOKPOINT, device="cpu")
print(f"✅ SAE loaded! Latents: {sae.num_latents}")

# Test SAE output structure
test_input = torch.randn(1, sae.d_in)
test_out = sae.encode(test_input)
print(f"SAE output type: {type(test_out)}")
print(f"SAE output attrs: {[a for a in dir(test_out) if not a.startswith('_')]}")
✅ SAE loaded! Latents: 262144
SAE output type: <class 'sparsify.fused_encoder.EncoderOutput'>
SAE output attrs: ['count', 'index', 'pre_acts', 'top_acts', 'top_indices']
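The `EncoderOutput` exposes only the top-k activations (`top_acts`) and their latent positions (`top_indices`); the monitor in the next section scatters these back into a dense vector and max-pools over tokens. A toy NumPy sketch of that scatter, with purely illustrative shapes:

```python
import numpy as np

# Toy stand-ins for EncoderOutput.top_indices / .top_acts:
# 2 tokens, k=3 active latents each, out of 10 total latents.
num_latents = 10
top_indices = np.array([[1, 4, 7], [0, 4, 9]])
top_acts    = np.array([[0.5, 1.2, 0.3], [0.8, 0.1, 2.0]])

dense = np.zeros((2, num_latents))
rows = np.arange(2)[:, None]          # broadcast token index against its k slots
dense[rows, top_indices] = top_acts   # scatter: dense[token, latent] = activation

per_latent_max = dense.max(axis=0)    # max-pool over tokens, as extract_features does
```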

3. RAGLens Monitor¶

In [13]:
class RAGLensMonitor:
    def __init__(self, model, tokenizer, sae, layer_idx=23):
        self.model, self.tokenizer, self.sae = model, tokenizer, sae
        self.layer_idx = layer_idx
        self.detection_log = []
        self._hidden = None
        
    def _hook(self, m, i, o): self._hidden = o.detach()  # capture this layer's MLP output
        
    def extract_features(self, text: str) -> np.ndarray:
        hook = self.model.model.layers[self.layer_idx].mlp.register_forward_hook(self._hook)
        try:
            inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=256)
            inputs = {k: v.to(self.model.device) for k, v in inputs.items()}
            with torch.inference_mode(): self.model(**inputs)
            h = self._hidden[0] if self._hidden.dim()==3 else self._hidden
            
            enc_out = self.sae.encode(h.cpu().float())
            
            if hasattr(enc_out, 'top_acts'):
                latent = torch.zeros(h.shape[0], self.sae.num_latents)
                latent.scatter_(1, enc_out.top_indices, enc_out.top_acts.detach())
            elif hasattr(enc_out, 'latent_acts'):
                latent = enc_out.latent_acts.detach()
            elif hasattr(enc_out, 'latents'):
                latent = enc_out.latents.detach()
            elif isinstance(enc_out, torch.Tensor):
                latent = enc_out.detach()
            else:
                for attr in dir(enc_out):
                    val = getattr(enc_out, attr)
                    if isinstance(val, torch.Tensor) and val.dim() >= 1:
                        latent = val.detach()
                        break
                else:
                    return np.zeros(self.sae.num_latents)
            
            return latent.detach().max(0)[0].numpy()
        except Exception as e:
            print(f"Feature error: {e}")
            return np.zeros(self.sae.num_latents)
        finally: 
            hook.remove()
            self._hidden = None
            
    def detect(self, text: str, step: str="") -> Dict:
        # Density heuristic: flag a step when an unusually large fraction
        # of SAE latents fire above a small activation threshold.
        f = self.extract_features(text)
        active = (f > 0.1).sum()
        prob = min(active / len(f) * 2, 1.0)
        result = {"step": step, "hallucination": prob > 0.3, "confidence": float(prob)}
        self.detection_log.append(result)
        return result
        
    def get_summary(self) -> Dict:
        if not self.detection_log: return {"total": 0, "hallucinations": 0}
        h = sum(1 for d in self.detection_log if d["hallucination"])
        return {"total": len(self.detection_log), "hallucinations": h}
        
    def clear(self): self.detection_log = []

print("✅ RAGLensMonitor")
✅ RAGLensMonitor
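The detector flags a step when `prob > 0.3`, i.e. when more than 15% of all latents exceed 0.1. With this SAE's 262,144 latents that is a high bar, and since the SAE only keeps top-k activations per token, the number of distinct active latents is bounded by k × tokens, so in practice `prob` stays far below the threshold (the runs below show ≤ 0.4%). The implied minimum:

```python
# How many active latents would it take to flag a step?
num_latents = 262_144
threshold = 0.3
# prob = active / num_latents * 2  >  threshold
# =>  active  >  threshold / 2 * num_latents
min_active = threshold / 2 * num_latents
print(f"Need > {min_active:,.0f} active latents to flag a step")  # ~39,322
```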

4. Tools¶

In [14]:
class ToolRegistry:
    def __init__(self): self.tools, self.descs = {}, {}
    def register(self, name, func, desc): self.tools[name]=func; self.descs[name]=desc
    def execute(self, name, args):
        if name not in self.tools: return f"Unknown tool: {name}"
        try: return str(self.tools[name](args))
        except Exception as e: return f"Error: {e}"
    def desc_str(self): return "\n".join(f"- {n}: {d}" for n,d in self.descs.items())

def calculator(expr):
    expr = expr.strip().strip('"\'')
    safe = {"sqrt": math.sqrt, "sin": math.sin, "cos": math.cos, "pi": math.pi, "pow": pow, "abs": abs}
    return round(eval(expr, {"__builtins__":{}}, safe), 6)

def python_exec(code):
    code = code.strip().strip('"\'')
    import io, sys
    old = sys.stdout; sys.stdout = cap = io.StringIO()
    try:
        # Run statements with a restricted builtins dict; if nothing was
        # printed, fall back to evaluating the code as an expression.
        exec(code, {"math": math, "__builtins__": {"print": print, "range": range, "len": len, "sum": sum}})
        return cap.getvalue().strip() or str(eval(code, {"math": math}))
    except Exception:
        return str(eval(code, {"math": math}))
    finally:
        sys.stdout = old

def search(q):
    q = q.lower().strip().strip('"\'')
    kb = {"light": "Speed of light ≈ 3×10⁸ m/s", "gravity": "g ≈ 9.8 m/s²",
          "sqrt": "√256 = 16, √144 = 12, √100 = 10", "square": "√256 = 16"}
    for k,v in kb.items():
        if k in q: return v
    return f"No info for '{q}'"

def convert(s):
    s = s.strip().strip('"\'')
    m = re.match(r"([\d.]+)\s*(\w+)\s+to\s+(\w+)", s.lower())
    if not m: return "Format: 'value unit1 to unit2'"
    v, fr, to = float(m[1]), m[2], m[3]
    if fr=="celsius" and to=="fahrenheit": return f"{v}°C = {v*9/5+32:.2f}°F"
    if fr=="km" and to=="miles": return f"{v}km = {v/1.609:.2f}mi"
    return "Unsupported"

def finish(ans): return f"FINAL: {ans.strip()}"

print("✅ Tools")
✅ Tools
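The `convert` tool relies on a single regex to split "value unit1 to unit2" queries. A standalone check of that pattern and the Celsius-to-Fahrenheit formula the tool applies:

```python
import re

# Same pattern convert() uses to parse "value unit1 to unit2".
pattern = r"([\d.]+)\s*(\w+)\s+to\s+(\w+)"
m = re.match(pattern, "100 celsius to fahrenheit")
value, src, dst = float(m[1]), m[2], m[3]
fahrenheit = value * 9 / 5 + 32       # same formula the tool applies
print(value, src, dst, fahrenheit)    # 100.0 celsius fahrenheit 212.0
```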

5. ReAct Agent¶

In [15]:
@dataclass
class AgentStep:
    step: int
    thought: str
    action: str
    input: str
    observation: str
    hall: bool = False
    conf: float = 0.0

@dataclass
class AgentResult:
    question: str
    answer: str
    steps: List[AgentStep]
    success: bool
In [16]:
class ReActAgent:
    def __init__(self, model, tokenizer, monitor=None, max_steps=5):
        self.model, self.tokenizer, self.monitor = model, tokenizer, monitor
        self.max_steps = max_steps
        self.tools = ToolRegistry()
        self.tools.register("calculator", calculator, "Evaluate math: calculator(sqrt(256))")
        self.tools.register("python", python_exec, "Run Python: python(math.sqrt(256))")
        self.tools.register("search", search, "Search facts: search(speed of light)")
        self.tools.register("convert", convert, "Convert units: convert(100 celsius to fahrenheit)")
        self.tools.register("finish", finish, "Final answer: finish(16)")
        
    def _prompt(self, q, history):
        p = f"""Answer questions using tools. Stop immediately when you have the answer.

Tools:
{self.tools.desc_str()}

IMPORTANT: Output ONLY in this format, nothing else:
Thought: [your reasoning]
Action: [tool name]
Action Input: [single line input]

Question: {q}
"""
        for s in history:
            p += f"\nThought: {s.thought}\nAction: {s.action}\nAction Input: {s.input}\nObservation: {s.observation}\n"
        return p
        
    def _parse(self, r):
        thought = action = inp = ""
        if m := re.search(r"Thought:\s*(.+?)(?=\nAction:|$)", r, re.DOTALL):
            thought = m[1].strip().split("\n")[0]
        if m := re.search(r"Action:\s*(\w+)", r):
            action = m[1].strip()
        if m := re.search(r"Action Input:\s*(.+)", r):
            inp = m[1].strip().split("\n")[0]
            if "Answer:" in inp: inp = inp.split("Answer:")[0].strip()
        return thought, action, inp
        
    def _generate(self, prompt):
        msgs = [{"role": "user", "content": prompt}]
        input_ids = self.tokenizer.apply_chat_template(msgs, return_tensors="pt", add_generation_prompt=True)
        input_ids = input_ids.to(self.model.device)
        input_len = input_ids.shape[1]
        with torch.inference_mode():
            out = self.model.generate(input_ids, max_new_tokens=150, temperature=0.1, do_sample=True,
                                       pad_token_id=self.tokenizer.eos_token_id)
        return self.tokenizer.decode(out[0][input_len:], skip_special_tokens=True)
        
    def run(self, question):
        print("="*50)
        print(f"❓ {question}")
        print("="*50)
        
        if self.monitor: self.monitor.clear()
        history = []
        answer = None
        
        for step in range(1, self.max_steps + 1):
            print(f"\n--- Step {step}/{self.max_steps} ---")
            
            resp = self._generate(self._prompt(question, history))
            thought, action, inp = self._parse(resp)
            
            print(f"💭 {thought[:50]}" + ("..." if len(thought)>50 else ""))
            print(f"🔧 {action}: {inp}")
            
            obs = self.tools.execute(action, inp) if action else "No action"
            print(f"👁️ {obs}")
            
            hall, conf = False, 0.0
            if self.monitor:
                det = self.monitor.detect(f"{thought} {obs}", f"Step {step}")
                hall, conf = det["hallucination"], det["confidence"]
                print(f"{'⚠️' if hall else '✅'} RAGLens: {conf:.1%}")
            
            history.append(AgentStep(step, thought, action or "none", inp, obs, hall, conf))
            
            if action == "finish" or "FINAL:" in obs:
                answer = obs.replace("FINAL:", "").strip()
                print(f"\n✅ DONE: {answer}")
                break
                
        if not answer:
            print("\n⚠️ Max steps")
            answer = "Unknown"
            
        if self.monitor:
            s = self.monitor.get_summary()
            print(f"📊 RAGLens: {s['hallucinations']}/{s['total']}")
            
        return AgentResult(question, answer, history, answer != "Unknown")

print("✅ ReActAgent")
✅ ReActAgent
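The whole loop hinges on `_parse` extracting three fields from free-form model output. A quick standalone check of those same regexes on a hypothetical completion (the sample text is made up for illustration):

```python
import re

sample = """Thought: I should use the calculator.
Action: calculator
Action Input: sqrt(256)
Trailing chatter the parser should ignore."""

# Same three patterns _parse applies to the model's response.
thought = re.search(r"Thought:\s*(.+?)(?=\nAction:|$)", sample, re.DOTALL)[1].strip().split("\n")[0]
action = re.search(r"Action:\s*(\w+)", sample)[1]
inp = re.search(r"Action Input:\s*(.+)", sample)[1].strip().split("\n")[0]
print(thought, action, inp)  # I should use the calculator. calculator sqrt(256)
```

The lazy `(.+?)` plus the `(?=\nAction:|$)` lookahead stops the thought at the next `Action:` line, and the `.+` in the input pattern stops at the newline, so trailing chatter is discarded.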

6. Run Tests 🚀¶

In [17]:
monitor = RAGLensMonitor(model, tokenizer, sae, layer_idx=LAYER_IDX)
agent = ReActAgent(model, tokenizer, monitor, max_steps=5)

print("\n🎯 Test 1: Math")
r1 = agent.run("What is sqrt(256)?")
🎯 Test 1: Math
==================================================
❓ What is sqrt(256)?
==================================================

--- Step 1/5 ---
💭 The problem is asking for the square root of 256.
🔧 calculator: sqrt(256)
👁️ 16.0
✅ RAGLens: 0.3%

--- Step 2/5 ---
💭 The problem is asking for the square root of 256.
🔧 python: math.sqrt(256)
👁️ 16.0
✅ RAGLens: 0.3%

--- Step 3/5 ---
💭 The problem is asking for the square root of 256.
🔧 finish: 16
👁️ FINAL: 16
✅ RAGLens: 0.3%

✅ DONE: 16
📊 RAGLens: 0/3
In [18]:
print("\n🎯 Test 2: Conversion")
r2 = agent.run("Convert 100 celsius to fahrenheit")
🎯 Test 2: Conversion
==================================================
❓ Convert 100 celsius to fahrenheit
==================================================

--- Step 1/5 ---
💭 We need to convert a temperature from Celsius to F...
🔧 convert: 100 celsius to fahrenheit
👁️ 100.0°C = 212.00°F
✅ RAGLens: 0.4%

--- Step 2/5 ---
💭 We need to convert a temperature from Celsius to F...
🔧 convert: 100 celsius to fahrenheit
👁️ 100.0°C = 212.00°F
✅ RAGLens: 0.4%

--- Step 3/5 ---
💭 We need to convert a temperature from Celsius to F...
🔧 convert: 100 celsius to fahrenheit
👁️ 100.0°C = 212.00°F
✅ RAGLens: 0.4%

--- Step 4/5 ---
💭 We need to convert a temperature from Celsius to F...
🔧 convert: 100 celsius to fahrenheit
👁️ 100.0°C = 212.00°F
✅ RAGLens: 0.4%

--- Step 5/5 ---
💭 We need to convert a temperature from Celsius to F...
🔧 convert: 100 celsius to fahrenheit
👁️ 100.0°C = 212.00°F
✅ RAGLens: 0.4%

⚠️ Max steps
📊 RAGLens: 0/5
In [19]:
print("\n🎯 Test 3: Knowledge")
r3 = agent.run("What is the speed of light?")
🎯 Test 3: Knowledge
==================================================
❓ What is the speed of light?
==================================================

--- Step 1/5 ---
💭 The speed of light is a fundamental constant in ph...
🔧 search: speed of light
👁️ Speed of light ≈ 3×10⁸ m/s
✅ RAGLens: 0.4%

--- Step 2/5 ---
💭 The speed of light is a fundamental constant in ph...
🔧 search: speed of light in m/s
👁️ Speed of light ≈ 3×10⁸ m/s
✅ RAGLens: 0.4%

--- Step 3/5 ---
💭 The speed of light is a fundamental constant in ph...
🔧 search: speed of light in m/s
👁️ Speed of light ≈ 3×10⁸ m/s
✅ RAGLens: 0.4%

--- Step 4/5 ---
💭 The speed of light is a fundamental constant in ph...
🔧 search: speed of light in m/s
👁️ Speed of light ≈ 3×10⁸ m/s
✅ RAGLens: 0.4%

--- Step 5/5 ---
💭 The speed of light is a fundamental constant in ph...
🔧 finish: 3×10⁸
👁️ FINAL: 3×10⁸
✅ RAGLens: 0.4%

✅ DONE: 3×10⁸
📊 RAGLens: 0/5