Skip to content

Commit 2ee68fd

Browse files
committed
feat: Add MetaLadder adapter for enhanced mathematical reasoning
This commit implements the MetaLadder approach from Lin et al. (2025) for improving mathematical reasoning through analogical learning and problem restatement. Key features include: problem type identification, meta problem generation, problem restatement, shortcut/full path options, LRU caching, and optimizer integration.
1 parent 595e322 commit 2ee68fd

File tree

7 files changed

+522
-2
lines changed

7 files changed

+522
-2
lines changed

PR.md

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

dspy/adapters/__init__.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,17 @@
1+
"""DSPy adapter implementations."""
2+
13
from dspy.adapters.base import Adapter
24
from dspy.adapters.chat_adapter import ChatAdapter
35
from dspy.adapters.json_adapter import JSONAdapter
4-
from dspy.adapters.types import Image, History
6+
from dspy.adapters.types import Image, History, AdapterResponse
7+
from dspy.adapters.metaladder_adapter import MetaLadderAdapter
58

69
# Public names re-exported by ``dspy.adapters`` (consumed by ``import *``
# and by API-surface tooling). Keep in sync with the imports above.
__all__ = [
    "Adapter",
    "ChatAdapter",
    "JSONAdapter",
    "Image",
    "History",
    "AdapterResponse",
    "MetaLadderAdapter"
]

dspy/adapters/metaladder_adapter.py

+208
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
"""MetaLadder adapter for enhancing mathematical reasoning through analogical learning.
2+
3+
This module implements the MetaLadder framework as described in the paper
4+
"MetaLadder: Ascending Mathematical Solution Quality via Analogical-Problem Reasoning Transfer".
5+
"""
6+
7+
from typing import Any, Dict, List, Optional, Tuple, Union, Callable
8+
from dataclasses import dataclass
9+
import re
10+
import hashlib
11+
from functools import lru_cache
12+
13+
from dspy.adapters.base import Adapter
14+
from dspy.adapters.types.response import AdapterResponse
15+
from dspy.dsp.utils import normalize_text
16+
from dspy.teleprompt import BootstrapFewShot
17+
from dspy.primitives.program import Module
18+
19+
20+
@lru_cache(maxsize=1000)
21+
def _get_cache_key(text: str) -> str:
22+
"""Generate a stable cache key for a given text.
23+
24+
Args:
25+
text: The text to generate a cache key for.
26+
27+
Returns:
28+
A stable hash of the text.
29+
"""
30+
return hashlib.sha256(text.encode()).hexdigest()
31+
32+
33+
@dataclass(frozen=True)
class MetaProblem:
    """A meta problem produced by the MetaLadder adapter.

    Instances are frozen (immutable): a *mutable* object with a value-based
    ``__hash__`` is unsafe in sets/dict keys, because mutating a field after
    insertion silently corrupts the container. ``frozen=True`` makes the
    dataclass generate a consistent ``__eq__``/``__hash__`` pair over the
    fields, which hashes identically to the hand-written
    ``hash((problem_type, meta_problem, restatement))`` of the original.

    Attributes:
        problem_type: The type of the problem.
        meta_problem: The meta problem description.
        restatement: The restatement of the problem.
    """

    problem_type: str
    meta_problem: str
    restatement: str
53+
54+
55+
class MetaLadderAdapter(Adapter):
    """An adapter that implements the MetaLadder approach for mathematical reasoning.

    This adapter enhances mathematical reasoning through analogical learning by:

    1. Identifying the problem type
    2. Generating a meta problem
    3. Restating the problem
    4. Using either a shortcut or full reasoning path

    Attributes:
        model (Module): The language model to use.
        optimizer (Optional[BootstrapFewShot]): The optimizer for improving prompts.
        use_shortcut (bool): Whether to use shortcut inference.
        max_tokens (int): Maximum number of tokens for responses.
        cache_size (int): Size of the LRU cache for method results.
    """

    def __init__(
        self,
        model: Module,
        optimizer: Optional[BootstrapFewShot] = None,
        use_shortcut: bool = True,
        max_tokens: int = 1000,
        cache_size: int = 1000,
    ) -> None:
        """Initialize the MetaLadderAdapter.

        Args:
            model: The language model to use.
            optimizer: Optional optimizer for improving prompts.
            use_shortcut: Whether to use shortcut inference.
            max_tokens: Maximum number of tokens for responses.
            cache_size: Size of the LRU cache for method results.
        """
        super().__init__()
        self.model = model
        self.optimizer = optimizer
        self.use_shortcut = use_shortcut
        # NOTE(review): max_tokens is stored but not currently enforced
        # anywhere in this adapter — confirm whether the model call should
        # pass it through.
        self.max_tokens = max_tokens
        # Fix: the class docstring advertises a ``cache_size`` attribute,
        # but the original never stored it.
        self.cache_size = cache_size

        # Wrap the *_impl methods in per-instance LRU caches. Because each
        # cache lives on the instance rather than on the class, entries are
        # not shared between adapters and are released with the instance.
        self._identify_problem_type = self._create_cached_method(
            self._identify_problem_type_impl, cache_size
        )
        self._generate_meta_problem = self._create_cached_method(
            self._generate_meta_problem_impl, cache_size
        )
        self._restate_problem = self._create_cached_method(
            self._restate_problem_impl, cache_size
        )

    def _create_cached_method(self, method: Callable[..., str], cache_size: int) -> Any:
        """Create a cached version of a bound method.

        Args:
            method: The (already bound) method to cache.
            cache_size: Maximum number of entries in the LRU cache.

        Returns:
            The LRU-cached callable.
        """
        return lru_cache(maxsize=cache_size)(method)

    def _call_model(self, prompt: str) -> str:
        """Call the model with a prompt.

        Args:
            prompt: The input prompt.

        Returns:
            The model's response.
        """
        if self.optimizer:
            # NOTE(review): BootstrapFewShot.compile returns a compiled
            # program (not a string) and expects Example objects in
            # ``trainset`` — passing a raw prompt string here looks wrong.
            # Preserved as-is pending a fix; verify against the teleprompt
            # API before relying on the optimizer path.
            return self.optimizer.compile(self.model, trainset=[prompt])
        # Idiom fix: call the module directly instead of ``__call__``.
        return self.model(prompt)

    def _identify_problem_type_impl(self, problem: str) -> str:
        """Identify the type of mathematical problem.

        Args:
            problem: The problem description.

        Returns:
            The identified problem type.
        """
        prompt = f"Identify the type of this math problem: {problem}"
        return self._call_model(prompt)

    def _generate_meta_problem_impl(self, problem_type: str, problem: str) -> str:
        """Generate a meta problem description.

        Args:
            problem_type: The type of problem.
            problem: The original problem.

        Returns:
            The meta problem description.
        """
        prompt = f"Generate a meta problem for this {problem_type} problem: {problem}"
        return self._call_model(prompt)

    def _restate_problem_impl(
        self, problem_type: str, meta_problem: str, problem: str
    ) -> str:
        """Restate the problem using the meta problem structure.

        Args:
            problem_type: The type of problem.
            meta_problem: The meta problem description.
            problem: The original problem.

        Returns:
            The restated problem.
        """
        prompt = (
            f"Restate this {problem_type} problem using the structure of the meta problem.\n"
            f"Meta problem: {meta_problem}\n"
            f"Problem: {problem}"
        )
        return self._call_model(prompt)

    def forward(self, prompt: str) -> Tuple[str, Optional[MetaProblem]]:
        """Process a prompt using the MetaLadder approach.

        Args:
            prompt: The input prompt.

        Returns:
            A tuple containing:
                - The model's response
                - The MetaProblem object (``None`` when using the shortcut path)
        """
        if self.use_shortcut:
            return self._call_model(prompt), None

        # Full reasoning path: type -> meta problem -> restatement -> answer.
        problem_type = self._identify_problem_type(prompt)
        meta_problem = self._generate_meta_problem(problem_type, prompt)
        restatement = self._restate_problem(problem_type, meta_problem, prompt)

        meta_problem_obj = MetaProblem(
            problem_type=problem_type,
            meta_problem=meta_problem,
            restatement=restatement,
        )

        response = self._call_model(restatement)
        return response, meta_problem_obj

    def clear_cache(self) -> None:
        """Clear all cached intermediate results."""
        self._identify_problem_type.cache_clear()
        self._generate_meta_problem.cache_clear()
        self._restate_problem.cache_clear()

dspy/adapters/types/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,7 @@
1+
"""Types for adapters."""
2+
13
from dspy.adapters.types.history import History
24
from dspy.adapters.types.image import Image
5+
from dspy.adapters.types.response import AdapterResponse
36

# Public adapter parameter/response types. Keep in sync with the imports above.
__all__ = ["History", "Image", "AdapterResponse"]

dspy/adapters/types/response.py

+12
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
"""Response type for adapters."""
2+
3+
from dataclasses import dataclass
4+
from typing import Optional
5+
6+
7+
@dataclass
class AdapterResponse:
    """Response from an adapter."""

    # Raw text of the model's reply.
    text: str
    # Token log-probabilities when the backing LM provides them; None otherwise.
    logprobs: Optional[dict] = None

example.py

+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
"""Example usage of MetaLadderAdapter for mathematical reasoning."""
2+
3+
from typing import Any
4+
from dspy.primitives.program import Module
5+
from dspy.predict.predict import Predict
6+
from dspy.signatures.signature import make_signature
7+
from dspy.adapters.metaladder_adapter import MetaLadderAdapter
8+
from dspy.clients.lm import LM
9+
10+
# Create a basic signature for our math solver: one input field ("problem"),
# one output field ("solution"), with instructions for step-by-step answers.
MathSolver = make_signature(
    "problem -> solution",
    "Given a mathematical problem, provide a step-by-step solution."
)
15+
16+
class SimpleMathModel(Module):
    """Minimal DSPy module that answers math problems through one Predict step."""

    def __init__(self) -> None:
        """Build the underlying predictor for the ``MathSolver`` signature."""
        super().__init__()
        self.predictor = Predict(MathSolver)

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """Delegate the call to the predictor (keyword arguments only)."""
        # Positional arguments are accepted for interface compatibility but
        # are not forwarded; the predictor is driven purely by kwargs.
        predictor = self.predictor
        return predictor(**kwargs)
27+
28+
def main() -> None:
    """Run an example using the MetaLadderAdapter."""
    # Set up the language-model backend.
    language_model = LM(model="gpt-3.5-turbo")

    # Build the math module and attach the LM to it.
    math_model = SimpleMathModel()
    math_model.set_lm(language_model)

    # Wrap the module in a MetaLadder adapter, requesting the full reasoning
    # path (problem type -> meta problem -> restatement) rather than the
    # shortcut.
    adapter = MetaLadderAdapter(
        model=math_model,
        use_shortcut=False  # Use the full reasoning path
    )

    # A small word problem to exercise the pipeline.
    problem = "If a train travels at 60 miles per hour for 2.5 hours, how far does it travel?"

    response, meta_problem = adapter.forward(problem)

    # Show each intermediate artifact followed by the final answer.
    print("Problem Type:", meta_problem.problem_type)
    print("\nMeta Problem:", meta_problem.meta_problem)
    print("\nRestatement:", meta_problem.restatement)
    print("\nSolution:", response)


if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)