Upload README.md
Browse files
README.md
CHANGED
|
@@ -1,141 +1,63 @@
|
|
| 1 |
-
---
|
| 2 |
-
tags:
|
| 3 |
-
- ml-intern
|
| 4 |
-
---
|
| 5 |
# Speculative Tool Actions
|
| 6 |
|
| 7 |
Investigating whether speculative decoding can be adapted from token prediction to agent action prediction.
|
| 8 |
|
| 9 |
-
|
| 10 |
|
| 11 |
-
|
| 12 |
|
| 13 |
-
##
|
| 14 |
-
|
| 15 |
-
- `tool_call` - Execute external tool
|
| 16 |
-
- `retrieval` - Retrieve information
|
| 17 |
-
- `file_read` - Read from file system
|
| 18 |
-
- `file_write` - Write to file system
|
| 19 |
-
- `repair` - Attempt self-repair
|
| 20 |
-
- `verifier` - Invoke verification
|
| 21 |
-
- `ask_clarification` - Request user clarification
|
| 22 |
-
- `final_answer` - Provide final response
|
| 23 |
-
- `BLOCKED` - Block unsafe action (critical for safety)
|
| 24 |
-
|
| 25 |
-
## Architecture
|
| 26 |
-
|
| 27 |
-
```
|
| 28 |
-
User Task
|
| 29 |
-
β
|
| 30 |
-
βΌ
|
| 31 |
-
βββββββββββββββββββ βββββββββββββββββββ βββββββββββββββββββ
|
| 32 |
-
β Cheap Model ββββββΆβ Verifier ββββββΆβ Strong Model β
|
| 33 |
-
β (Qwen3-1.7B) β β (Strong or β β (Qwen2.5-7B) β
|
| 34 |
-
β Proposes action β β Trained 4B) β β Fallback/Repair β
|
| 35 |
-
βββββββββββββββββββ βββββββββββββββββββ βββββββββββββββββββ
|
| 36 |
-
```
|
| 37 |
-
|
| 38 |
-
## Evaluation Configurations
|
| 39 |
-
|
| 40 |
-
| Config | Name | Description |
|
| 41 |
-
|--------|------|-------------|
|
| 42 |
-
| A | Always Strong | Baseline: always use strong model |
|
| 43 |
-
| B | Cheap Only | Always use cheap model |
|
| 44 |
-
| C | Cheap + Strong Verifier | Propose cheap, verify with strong |
|
| 45 |
-
| D | Cheap + Trained Judge | Propose cheap, verify with trained verifier |
|
| 46 |
-
| E | Multi-Proposal Reranking | Generate multiple proposals, strong reranks |
|
| 47 |
-
|
| 48 |
-
## Datasets
|
| 49 |
-
|
| 50 |
-
| Dataset | Size | Purpose |
|
| 51 |
-
|---------|------|---------|
|
| 52 |
-
| `speculative-actions-proposer-sft` | 5K train, 17K test | Proposer training (SFT) |
|
| 53 |
-
| `speculative-actions-verifier-pref` | 5K train, 17K test | Verifier training (Reward) |
|
| 54 |
-
| `speculative-actions-eval` | 500 examples | Evaluation benchmark |
|
| 55 |
-
|
| 56 |
-
## Models
|
| 57 |
-
|
| 58 |
-
| Model | Size | Role |
|
| 59 |
-
|-------|------|------|
|
| 60 |
-
| `Qwen/Qwen3-1.7B` | 1.7B | Proposer (cheap) |
|
| 61 |
-
| `Qwen/Qwen3-4B` | 4B | Trained verifier |
|
| 62 |
-
| `Qwen/Qwen2.5-7B` | 7B | Strong model (baseline) |
|
| 63 |
-
|
| 64 |
-
## Results (Expected)
|
| 65 |
-
|
| 66 |
-
| Config | Accuracy | Avg Cost | Safety |
|
| 67 |
-
|--------|----------|----------|--------|
|
| 68 |
-
| A | 0.85 | 1.00 | 0.82 |
|
| 69 |
-
| B | 0.62 | 0.20 | 0.65 |
|
| 70 |
-
| C | 0.78 | 0.55 | 0.88 |
|
| 71 |
-
| D | 0.75 | 0.42 | 0.85 |
|
| 72 |
-
| E | 0.81 | 0.75 | 0.80 |
|
| 73 |
-
|
| 74 |
-
**Best trade-off**: Config D (Cheap + Trained Judge) - 75% accuracy at 42% of the cost.
|
| 75 |
-
|
| 76 |
-
## Files
|
| 77 |
-
|
| 78 |
-
- `synthetic_data_and_train.py` - Full pipeline (data gen + train)
|
| 79 |
-
- `generate_data_only.py` - Standalone dataset generator
|
| 80 |
-
- `train_proposer.py` - Proposer SFT training
|
| 81 |
-
- `train_verifier.py` - Verifier RewardModel training
|
| 82 |
-
- `eval_base_models.py` - Evaluation with base models
|
| 83 |
-
- `eval_runner.py` - Evaluation with trained models
|
| 84 |
-
- `eval_runner_simple.py` - Simulated evaluation
|
| 85 |
-
- `ABLACTION_REPORT.md` - Detailed ablation report
|
| 86 |
-
|
| 87 |
-
## Usage
|
| 88 |
-
|
| 89 |
-
### Generate Data
|
| 90 |
```bash
|
| 91 |
-
python
|
| 92 |
```
|
| 93 |
|
| 94 |
-
|
| 95 |
-
```
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
### Train Verifier
|
| 100 |
-
```bash
|
| 101 |
-
python train_verifier.py
|
| 102 |
```
|
| 103 |
|
| 104 |
-
###
|
| 105 |
```bash
|
| 106 |
-
python
|
| 107 |
```
|
| 108 |
|
| 109 |
-
##
|
| 110 |
-
|
| 111 |
-
Costs are normalized relative to the strong model (1.0):
|
| 112 |
-
- Strong model (7B): 1.0 per inference
|
| 113 |
-
- Cheap model (1.7B): 0.2 per inference
|
| 114 |
-
- Verifier (4B): 0.3 per inference
|
| 115 |
-
- Trained judge (4B LoRA): 0.15 per inference
|
| 116 |
-
|
| 117 |
-
## Key Findings
|
| 118 |
-
|
| 119 |
-
1. Speculative action prediction achieves 88% of strong model accuracy at 42% cost
|
| 120 |
-
2. Verifier is crucial for safety - improves from 0.65 to 0.88
|
| 121 |
-
3. Trained judge nearly matches strong verifier at lower cost
|
| 122 |
-
4. Multi-proposal reranking is expensive and dominated by other configs
|
| 123 |
-
5. `final_answer` is easiest (95% accuracy); `repair` is hardest (55% cheap, 72% with verifier)
|
| 124 |
-
|
| 125 |
-
## Citation
|
| 126 |
-
|
| 127 |
-
Based on:
|
| 128 |
-
- *DualSpec*: Draft-Target Verification
|
| 129 |
-
- *Tool-Star*: Small Model for Draft Generation
|
| 130 |
-
- *TinyV*: Tiny Verifier for Efficient Verification
|
| 131 |
|
| 132 |
-
--
|
| 133 |
-
*Generated by ML Intern*
|
| 134 |
|
| 135 |
-
|
| 136 |
-
## Generated by ML Intern
|
| 137 |
|
| 138 |
-
|
| 139 |
|
| 140 |
-
|
| 141 |
-
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
# Speculative Tool Actions
|
| 2 |
|
| 3 |
Investigating whether speculative decoding can be adapted from token prediction to agent action prediction.
|
| 4 |
|
| 5 |
+
**Current state:** v2 evaluation complete (see [ABLATION_REPORT_v2.md](./ABLATION_REPORT_v2.md)). v3 datasets + 1.7B proposer trained. **Need:** train 4B verifier + 8B proposer, then run eval.
|
| 6 |
|
| 7 |
+
## Quick Start: Complete the Project
|
| 8 |
|
| 9 |
+
### One-command training (A100-large, ~2h):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
```bash
|
| 11 |
+
python train_all_v3.py
|
| 12 |
```
|
| 13 |
|
| 14 |
+
Or via HF Jobs:
|
| 15 |
+
```python
|
| 16 |
+
hf_jobs(operation="run", script="https://hf.co/narcolepticchicken/speculative-tool-actions/resolve/main/train_all_v3.py",
|
| 17 |
+
dependencies=["transformers>=4.51","trl","torch","datasets","accelerate","peft","huggingface_hub"],
|
| 18 |
+
hardware_flavor="a100-large", timeout="12h")
|
|
|
|
|
|
|
|
|
|
| 19 |
```
|
| 20 |
|
| 21 |
+
### Then evaluate:
|
| 22 |
```bash
|
| 23 |
+
python eval_runner_v3.py
|
| 24 |
```
|
| 25 |
|
| 26 |
+
## Architecture
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
A cheap model (Qwen3-1.7B LoRA) proposes the next agent action. A verifier (Qwen3-4B LoRA) accepts or rejects. On rejection, fall back to the expensive 8B model.
|
|
|
|
| 29 |
|
| 30 |
+
**Action space:** `tool_call`, `retrieval`, `file_read`, `file_write`, `repair`, `verifier`, `ask_clarification`, `final_answer`, `BLOCKED`
|
|
|
|
| 31 |
|
| 32 |
+
## Files
|
| 33 |
|
| 34 |
+
| File | Purpose |
|
| 35 |
+
|------|---------|
|
| 36 |
+
| `train_all_v3.py` | Consolidated: trains 1.7B+4B+8B sequentially |
|
| 37 |
+
| `train_sft_v3.py` | Individual proposer training |
|
| 38 |
+
| `train_verifier_v3.py` | Individual verifier training |
|
| 39 |
+
| `eval_runner_v3.py` | All-5-configs evaluation |
|
| 40 |
+
| `PROJECT_REPORT.md` | Full project documentation + v2 results |
|
| 41 |
+
| `ABLATION_REPORT_v2.md` | v2 analysis (51% cheap vs 40% frozen 8B) |
|
| 42 |
+
| `eval_results_v2.json` | v2 raw results |
|
| 43 |
+
|
| 44 |
+
## v2 Results
|
| 45 |
+
|
| 46 |
+
| Config | Acc | Cost |
|
| 47 |
+
|--------|-----|------|
|
| 48 |
+
| A: 8B frozen | 40% | 1.00 |
|
| 49 |
+
| B: 1.7B cheap | **51%** | **0.15** |
|
| 50 |
+
| D: cheap + 4B RM | 51% | 0.25 |
|
| 51 |
+
| E: multi-proposal | 42% | 0.75 |
|
| 52 |
+
|
| 53 |
+
See [ABLATION_REPORT_v2.md](./ABLATION_REPORT_v2.md) for analysis.
|
| 54 |
+
|
| 55 |
+
## v3 Status
|
| 56 |
+
|
| 57 |
+
| Component | Status |
|
| 58 |
+
|-----------|--------|
|
| 59 |
+
| Datasets (SFT, verifier, eval) | β Built |
|
| 60 |
+
| 1.7B proposer | β Trained |
|
| 61 |
+
| 4B verifier | β Needs training |
|
| 62 |
+
| 8B proposer | β Needs training |
|
| 63 |
+
| Eval runner | β Ready |
|