"""MALM Inference Script - Run directly from Hugging Face.

Usage:
    # Install dependencies
    pip install mlx huggingface_hub

    # Download and run
    huggingface-cli download codelion/malm-165m --local-dir ./malm-165m
    python malm-165m/inference.py --query "function that sorts a list"
"""

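# Programmatic usage (a minimal sketch; the model path and query below are
# illustrative, not fixed by this script):
#
#   from pathlib import Path
#   model, tokenizer, functions, config = load_model(Path("./malm-165m"))
#   for name, sig, doc, score in search_functions(
#       model, tokenizer, functions, "parse a json file", top_k=3
#   ):
#       print(f"{score:.3f}  {name}  {sig}")
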
import argparse
import json
import re
from pathlib import Path
from typing import Dict, List, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np


class MALM(nn.Module):
    """Memory-Augmented Language Model."""

    def __init__(
        self,
        vocab_size: int,
        d_model: int = 768,
        n_heads: int = 12,
        n_layers: int = 12,
        n_query_layers: int = 4,
        max_seq_len: int = 128,
        dropout: float = 0.0,
    ):
        super().__init__()
        self.vocab_size = vocab_size
        self.d_model = d_model
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.n_query_layers = n_query_layers
        self.max_seq_len = max_seq_len

        # Shared token and position embeddings
        self.embed = nn.Embedding(vocab_size, d_model)
        self.pos_embed = nn.Embedding(max_seq_len, d_model)
        self.embed_dropout = nn.Dropout(dropout)

        # Query encoder
        self.query_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_query_layers)
        ]
        self.query_ln = nn.LayerNorm(d_model)
        self.query_proj = nn.Linear(d_model, d_model)

        # Value encoder (same depth as the query encoder)
        self.value_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_query_layers)
        ]
        self.value_ln = nn.LayerNorm(d_model)
        self.value_proj = nn.Linear(d_model, d_model)

        # Decoder stack
        self.decoder_layers = [
            nn.TransformerEncoderLayer(d_model, n_heads, d_model * 4)
            for _ in range(n_layers)
        ]
        self.decoder_ln = nn.LayerNorm(d_model)

        # Output projection to vocabulary logits
        self.output = nn.Linear(d_model, vocab_size)

        # Learned retrieval temperature, stored in log space
        self.log_temp = mx.array([0.0])

    def encode_query(self, query_ids: mx.array) -> mx.array:
        """Encode a query to a single embedding via masked mean pooling."""
        # Truncate to the positional table size so position lookups can't overflow
        query_ids = query_ids[:, : self.max_seq_len]
        B, L = query_ids.shape

        h = self.embed(query_ids)
        h = h + self.pos_embed(mx.arange(L))
        h = self.embed_dropout(h)

        for layer in self.query_layers:
            h = layer(h, None)

        h = self.query_ln(h)

        # Mean-pool over non-padding positions (pad id is 0)
        mask = (query_ids != 0).astype(mx.float32)[:, :, None]
        h = h * mask
        query_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)

        return self.query_proj(query_emb)

    def encode_value(self, value_ids: mx.array) -> mx.array:
        """Encode a value to a single embedding via masked mean pooling."""
        # Truncate to the positional table size so position lookups can't overflow
        value_ids = value_ids[:, : self.max_seq_len]
        B, L = value_ids.shape

        h = self.embed(value_ids)
        h = h + self.pos_embed(mx.arange(L))

        for layer in self.value_layers:
            h = layer(h, None)

        h = self.value_ln(h)

        # Mean-pool over non-padding positions (pad id is 0)
        mask = (value_ids != 0).astype(mx.float32)[:, :, None]
        h = h * mask
        val_emb = mx.sum(h, axis=1) / (mx.sum(mask, axis=1) + 1e-8)

        return self.value_proj(val_emb)

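    # Note: encode_query and encode_value share the token / position embedding
    # tables but run separate transformer stacks and output projections, so
    # queries and values are encoded from the same input space by different
    # branches.
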
    def retrieve(
        self,
        query_emb: mx.array,
        key_emb: mx.array,
        val_emb: mx.array,
    ) -> Tuple[mx.array, mx.array, mx.array]:
        """Retrieve from memory with scaled dot-product attention over keys."""
        scale = self.d_model ** -0.5
        # exp(log_temp) + 0.1 keeps the learned temperature strictly positive
        temp = mx.exp(self.log_temp) + 0.1

        scores = (query_emb @ key_emb.T) * scale / temp
        attn = mx.softmax(scores, axis=-1)
        retrieved = attn @ val_emb

        return retrieved, attn, scores


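# Shape sketch for MALM.retrieve (toy sizes, not tied to any checkpoint): a
# query embedding of shape (1, d_model) scored against N key embeddings of
# shape (N, d_model) yields `scores` and `attn` of shape (1, N), and the
# attention-weighted sum over the N value embeddings gives `retrieved` of
# shape (1, d_model):
#
#   m = MALM(vocab_size=100, d_model=768)
#   q, k, v = mx.zeros((1, 768)), mx.zeros((10, 768)), mx.zeros((10, 768))
#   retrieved, attn, scores = m.retrieve(q, k, v)
#   # retrieved: (1, 768), attn: (1, 10), scores: (1, 10)
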
class Tokenizer:
    """Simple regex tokenizer for MALM."""

    def __init__(self, tokenizer_dict: Dict):
        self.token_to_id = tokenizer_dict.get("token_to_id", {})
        self.id_to_token = {int(v): k for k, v in self.token_to_id.items()}
        self.special = {"<PAD>": 0, "<UNK>": 1, "<BOS>": 2, "<EOS>": 3}

    def encode(self, text: str) -> List[int]:
        """Tokenize text into identifier, number, and punctuation tokens."""
        tokens = re.findall(r"[a-zA-Z_][a-zA-Z0-9_]*|[0-9]+|[^\s]", text.lower())
        return [self.token_to_id.get(t, self.special.get("<UNK>", 1)) for t in tokens]

    def decode(self, ids: List[int]) -> str:
        """Decode token IDs back to space-separated text."""
        tokens = [self.id_to_token.get(i, "<UNK>") for i in ids]
        return " ".join(tokens)


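# Tokenization example (the vocabulary entries here are hypothetical):
#
#   tok = Tokenizer({"token_to_id": {"sort": 4, "list": 5}})
#   tok.encode("Sort a list!")  # -> [4, 1, 5, 1]; "a" and "!" map to <UNK>
#   tok.decode([4, 5])          # -> "sort list"
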
def load_model(model_dir: Path):
    """Load the MALM model, tokenizer, and function index from a directory."""
    import mlx.utils as mlx_utils

    # Model hyperparameters
    with open(model_dir / "config.json") as f:
        config = json.load(f)

    model = MALM(
        vocab_size=config["vocab_size"],
        d_model=config["d_model"],
        n_heads=config["n_heads"],
        n_layers=config["n_layers"],
        n_query_layers=config["n_query_layers"],
        max_seq_len=config["max_seq_len"],
    )

    # Weights are stored as a flat name -> array mapping in an .npz archive
    weights = dict(np.load(model_dir / "model.npz"))
    weights = {k: mx.array(v) for k, v in weights.items()}

    params = mlx_utils.tree_unflatten(list(weights.items()))
    model.update(params)
    mx.eval(model.parameters())

    # Tokenizer vocabulary
    with open(model_dir / "tokenizer.json") as f:
        tokenizer_dict = json.load(f)
    tokenizer = Tokenizer(tokenizer_dict)

    # Function index used as the retrieval memory
    with open(model_dir / "functions.json") as f:
        functions = json.load(f)

    return model, tokenizer, functions, config


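# Expected model directory layout (as read by load_model above):
#
#   config.json     - hyperparameters: vocab_size, d_model, n_heads, n_layers,
#                     n_query_layers, max_seq_len (plus num_parameters, which
#                     main() prints)
#   model.npz       - flattened weights, restored via tree_unflatten
#   tokenizer.json  - {"token_to_id": {...}}
#   functions.json  - list of {"name", "signature", "docstring"} records
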
def search_functions(
    model: MALM,
    tokenizer: Tokenizer,
    functions: List[Dict],
    query: str,
    top_k: int = 5,
) -> List[Tuple[str, str, str, float]]:
    """Search for functions matching a natural language query.

    Uses the function name as the key and signature + docstring as the value
    for retrieval. Returns (name, signature, docstring, score) tuples.
    """
    query_ids = tokenizer.encode(query)
    if not query_ids:
        query_ids = [1]  # fall back to a single <UNK> token
    query_ids = mx.array([query_ids])

    # Build key / value token batches from the function index
    key_tokens = []
    value_tokens = []
    max_val_len = 64

    for func in functions:
        name = func["name"]
        sig = func.get("signature", name)
        doc = func.get("docstring", "")
        value_text = f"{sig} {doc}"

        # Key: the single vocabulary token for the function name (<UNK> if unseen)
        key_id = tokenizer.token_to_id.get(name.lower(), 1)
        key_tokens.append(key_id)

        # Value: signature + docstring, truncated and zero-padded to max_val_len
        val_ids = tokenizer.encode(value_text)[:max_val_len]
        val_ids = val_ids + [0] * (max_val_len - len(val_ids))
        value_tokens.append(val_ids)

    key_tokens = mx.array(key_tokens)
    value_tokens = mx.array(value_tokens)

    # Keys are raw token embeddings; values go through the value encoder
    key_emb = model.embed(key_tokens)
    val_emb = model.encode_value(value_tokens)

    # Score the query against every key
    query_emb = model.encode_query(query_ids)
    _, attn, scores = model.retrieve(query_emb, key_emb, val_emb)
    mx.eval(scores)

    # Rank functions by raw score, highest first
    scores_np = np.array(scores[0])
    top_indices = np.argsort(scores_np)[::-1][:top_k]

    results = []
    for idx in top_indices:
        func = functions[idx]
        score = float(scores_np[idx])
        sig = func.get("signature", func["name"])
        doc = func.get("docstring", "")
        results.append((func["name"], sig, doc, score))

    return results


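# Note: the ranking above uses the raw (pre-softmax) scores from
# model.retrieve; since softmax is monotonic, ranking by the normalized
# `attn` weights would give the same ordering.
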
def main():
    parser = argparse.ArgumentParser(description="MALM Inference - Semantic Code Search")
    parser.add_argument("--query", type=str, required=True, help="Natural language query")
    parser.add_argument("--top-k", type=int, default=5, help="Number of results")
    parser.add_argument("--model-dir", type=str, default=None, help="Model directory")
    args = parser.parse_args()

    # Default to the directory this script lives in
    if args.model_dir:
        model_dir = Path(args.model_dir)
    else:
        model_dir = Path(__file__).parent

    print(f"Loading model from {model_dir}...")
    model, tokenizer, functions, config = load_model(model_dir)
    print(f"Loaded {len(functions)} functions, {config['num_parameters']:,} parameters")

    print(f"\nQuery: {args.query}")
    print("-" * 60)

    results = search_functions(model, tokenizer, functions, args.query, args.top_k)

    for i, (name, signature, docstring, score) in enumerate(results, 1):
        print(f"\n{i}. {name} (score: {score:.4f})")
        print(f"   Signature: {signature}")
        if docstring:
            print(f"   Docstring: {docstring[:100]}{'...' if len(docstring) > 100 else ''}")


if __name__ == "__main__":
    main()