guohanghui committed on
Commit
8e8cf4f
·
verified ·
1 Parent(s): 413a390

Update pyro/mcp_output/mcp_plugin/mcp_service.py

Browse files
pyro/mcp_output/mcp_plugin/mcp_service.py CHANGED
@@ -1,97 +1,466 @@
 
 
 
 
 
 
 
 
 
1
  from fastmcp import FastMCP
 
 
 
 
 
2
 
3
  # Create the FastMCP service application
4
  mcp = FastMCP("pyro_service")
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- @mcp.tool(name="list_distributions", description="List all available distributions in Pyro")
8
  def list_distributions() -> dict:
9
  """
10
  List all available distributions in Pyro.
11
 
12
  Returns:
13
- - dict: A dictionary with success status and list of distributions.
14
  """
15
  try:
16
- from pyro.distributions import __all__ as distributions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  return {
18
  "success": True,
19
- "distributions": distributions,
20
- "count": len(distributions)
 
 
 
 
 
21
  }
22
  except Exception as e:
23
- return {"success": False, "error": str(e)}
24
 
25
 
26
- @mcp.tool(name="sample_from_distribution", description="Sample from a specific Pyro distribution")
27
- def sample_from_distribution(distribution_name: str, *args, **kwargs) -> dict:
28
  """
29
- Sample from a specific Pyro distribution.
30
 
31
  Parameters:
32
- - distribution_name: Name of the distribution (e.g., 'Normal', 'Beta')
33
- - *args, **kwargs: Parameters for the distribution
 
34
 
35
  Returns:
36
- - dict: Sampled values or error message.
37
  """
38
  try:
39
- from pyro.distributions import __dict__ as dist_dict
40
- if distribution_name not in dist_dict:
41
- return {
42
- "success": False,
43
- "error": f"Distribution '{distribution_name}' not found. Use list_distributions to see available options."
44
- }
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
- distribution = dist_dict[distribution_name](*args, **kwargs)
47
- samples = distribution.sample()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  return {
49
  "success": True,
50
- "samples": samples.tolist()
 
 
 
 
 
51
  }
52
  except Exception as e:
53
- return {"success": False, "error": str(e)}
54
 
55
 
56
- @mcp.tool(name="run_inference", description="Run inference using a specific Pyro model")
57
- def run_inference(model_name: str, data: dict) -> dict:
58
  """
59
- Run inference using a specific Pyro model.
60
 
61
  Parameters:
62
- - model_name: Name of the model to use
63
- - data: Input data for the model
64
 
65
  Returns:
66
- - dict: Inference results or error message.
67
  """
68
  try:
69
- from pyro.infer import SVI, Trace_ELBO
70
- from pyro.optim import Adam
71
- from pyro import __dict__ as pyro_dict
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- if model_name not in pyro_dict:
74
- return {
75
- "success": False,
76
- "error": f"Model '{model_name}' not found."
77
- }
78
 
79
- model = pyro_dict[model_name]
80
- guide = pyro_dict.get(f"{model_name}_guide")
81
- if not guide:
82
- return {
83
- "success": False,
84
- "error": f"Guide for model '{model_name}' not found."
85
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- svi = SVI(model, guide, Adam({"lr": 0.01}), Trace_ELBO())
88
- loss = svi.step(data)
 
 
 
 
 
 
 
89
  return {
90
  "success": True,
91
- "loss": loss
 
 
 
 
 
 
92
  }
93
  except Exception as e:
94
- return {"success": False, "error": str(e)}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
95
 
96
 
97
  def create_app() -> FastMCP:
@@ -99,6 +468,6 @@ def create_app() -> FastMCP:
99
  Create and return the FastMCP application instance.
100
 
101
  Returns:
102
- - FastMCP: The FastMCP application instance.
103
  """
104
  return mcp
 
1
+ import os
2
+ import sys
3
+ from typing import Dict, Any, List, Optional
4
+
5
+ # Add the local source directory to sys.path
6
+ source_path = os.path.join(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))), "source")
7
+ if source_path not in sys.path:
8
+ sys.path.insert(0, source_path)
9
+
10
  from fastmcp import FastMCP
11
+ import torch
12
+ import pyro
13
+ import pyro.distributions as dist
14
+ from pyro.infer import SVI, Trace_ELBO, MCMC, NUTS, Predictive
15
+ from pyro.optim import Adam
16
 
17
# Create the FastMCP service application
mcp = FastMCP("pyro_service")

# In-memory registries keyed by caller-chosen string IDs. They live only for
# the lifetime of the service process (no persistence across restarts).
_models = {}         # model_id -> model callable
_guides = {}         # guide_id -> variational guide callable/object
_svi_instances = {}  # svi_id -> trained SVI object
_mcmc_instances = {} # mcmc_id -> MCMC object; reserved — nothing in this file populates it
+
26
@mcp.tool(name="get_pyro_info")
def get_pyro_info() -> dict:
    """
    Get information about the Pyro library.

    Returns:
        dict: Version and configuration information about Pyro.
    """
    try:
        info = {
            "version": pyro.__version__,
            "torch_version": torch.__version__,
            "cuda_available": torch.cuda.is_available(),
            "backend": "torch",
        }
        return {"success": True, "result": info, "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
47
+
48
 
49
@mcp.tool(name="list_distributions")
def list_distributions() -> dict:
    """
    List all available distributions in Pyro.

    Returns:
        dict: A list of available distribution names organized by category.
    """
    try:
        # Curated catalogue, grouped by category; not auto-discovered from pyro.
        catalogue = {
            "basic": [
                "Normal", "Bernoulli", "Beta", "Binomial", "Categorical", "Cauchy",
                "Dirichlet", "Exponential", "Gamma", "Geometric", "LogNormal",
                "Multinomial", "MultivariateNormal", "Poisson", "Uniform"
            ],
            "hidden_markov_models": [
                "DiscreteHMM", "GaussianHMM", "GammaGaussianHMM", "LinearHMM",
                "GaussianMRF", "IndependentHMM"
            ],
            "advanced": [
                "Delta", "Empirical", "MixtureOfDiagNormals", "TransformedDistribution",
                "ConditionalDistribution", "ZeroInflatedPoisson", "ZeroInflatedNegativeBinomial"
            ],
        }

        result = dict(catalogue)
        result["total_count"] = sum(len(names) for names in catalogue.values())
        return {"success": True, "result": result, "error": None}
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
86
 
87
 
88
@mcp.tool(name="sample_normal")
def sample_normal(loc: float = 0.0, scale: float = 1.0, sample_shape: Optional[List[int]] = None) -> dict:
    """
    Sample from a Normal distribution.

    Parameters:
        loc (float): Mean of the distribution.
        scale (float): Standard deviation of the distribution.
        sample_shape (Optional[List[int]]): Shape of samples to draw.

    Returns:
        dict: Samples from the Normal distribution.
    """
    try:
        # An empty/None shape means a single draw (torch.Size() default).
        shape = torch.Size(sample_shape) if sample_shape else torch.Size()
        drawn = dist.Normal(loc, scale).sample(shape)
        payload = drawn.tolist() if isinstance(drawn, torch.Tensor) else float(drawn)
        return {
            "success": True,
            "result": {
                "samples": payload,
                "distribution": "Normal",
                "parameters": {"loc": loc, "scale": scale},
            },
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
119
+
120
 
121
@mcp.tool(name="sample_bernoulli")
def sample_bernoulli(probs: float = 0.5, sample_shape: Optional[List[int]] = None) -> dict:
    """
    Sample from a Bernoulli distribution.

    Parameters:
        probs (float): Probability of success (between 0 and 1).
        sample_shape (Optional[List[int]]): Shape of samples to draw.

    Returns:
        dict: Samples from the Bernoulli distribution.
    """
    try:
        # An empty/None shape means a single draw (torch.Size() default).
        shape = torch.Size(sample_shape) if sample_shape else torch.Size()
        drawn = dist.Bernoulli(probs).sample(shape)
        payload = drawn.tolist() if isinstance(drawn, torch.Tensor) else int(drawn)
        return {
            "success": True,
            "result": {
                "samples": payload,
                "distribution": "Bernoulli",
                "parameters": {"probs": probs},
            },
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
151
 
152
 
153
@mcp.tool(name="sample_categorical")
def sample_categorical(probs: List[float], sample_shape: Optional[List[int]] = None) -> dict:
    """
    Sample from a Categorical distribution.

    Parameters:
        probs (List[float]): Probabilities for each category (must sum to 1).
        sample_shape (Optional[List[int]]): Shape of samples to draw.

    Returns:
        dict: Samples from the Categorical distribution.
    """
    try:
        categorical = dist.Categorical(torch.tensor(probs))
        # An empty/None shape means a single draw (torch.Size() default).
        shape = torch.Size(sample_shape) if sample_shape else torch.Size()
        drawn = categorical.sample(shape)
        payload = drawn.tolist() if isinstance(drawn, torch.Tensor) else int(drawn)
        return {
            "success": True,
            "result": {
                "samples": payload,
                "distribution": "Categorical",
                "parameters": {"probs": probs},
            },
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
184
 
 
 
 
 
 
185
 
186
@mcp.tool(name="create_simple_model")
def create_simple_model(model_id: str, model_type: str = "normal_normal") -> dict:
    """
    Create a simple probabilistic model.

    Parameters:
        model_id (str): Unique identifier for the model.
        model_type (str): Type of model ('normal_normal', 'coin_flip', 'linear_regression').

    Returns:
        dict: Information about the created model.
    """
    try:
        # Sample-site names below are part of each model's contract with guides
        # and inference code — keep them stable.
        def _normal_normal(data=None):
            loc = pyro.sample("loc", dist.Normal(0.0, 1.0))
            scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
            with pyro.plate("data", len(data) if data is not None else 1):
                return pyro.sample("obs", dist.Normal(loc, scale), obs=data)

        def _coin_flip(data=None):
            p = pyro.sample("p", dist.Beta(2.0, 2.0))
            with pyro.plate("data", len(data) if data is not None else 1):
                return pyro.sample("obs", dist.Bernoulli(p), obs=data)

        def _linear_regression(x=None, y=None):
            a = pyro.sample("a", dist.Normal(0.0, 10.0))
            b = pyro.sample("b", dist.Normal(0.0, 10.0))
            sigma = pyro.sample("sigma", dist.LogNormal(0.0, 1.0))
            if x is not None:
                mean = a + b * x
                with pyro.plate("data", len(x)):
                    return pyro.sample("obs", dist.Normal(mean, sigma), obs=y)

        builders = {
            "normal_normal": _normal_normal,
            "coin_flip": _coin_flip,
            "linear_regression": _linear_regression,
        }
        if model_type not in builders:
            return {"success": False, "result": None, "error": f"Unknown model type: {model_type}"}

        _models[model_id] = builders[model_type]

        return {
            "success": True,
            "result": {
                "model_id": model_id,
                "model_type": model_type,
                "message": f"Model '{model_id}' created successfully",
            },
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
237
+
238
+
239
@mcp.tool(name="create_guide")
def create_guide(guide_id: str, model_id: str, guide_type: str = "auto_normal") -> dict:
    """
    Create a variational guide for a model.

    Parameters:
        guide_id (str): Unique identifier for the guide.
        model_id (str): ID of the model to create a guide for.
        guide_type (str): Type of guide ('auto_normal', 'auto_delta').

    Returns:
        dict: Information about the created guide.
    """
    try:
        if model_id not in _models:
            return {"success": False, "result": None, "error": f"Model '{model_id}' not found"}

        model = _models[model_id]

        # Import lazily so unsupported guide types never touch the autoguide module.
        if guide_type == "auto_normal":
            from pyro.infer.autoguide import AutoNormal as _GuideCls
        elif guide_type == "auto_delta":
            from pyro.infer.autoguide import AutoDelta as _GuideCls
        else:
            return {"success": False, "result": None, "error": f"Unknown guide type: {guide_type}"}

        _guides[guide_id] = _GuideCls(model)

        return {
            "success": True,
            "result": {
                "guide_id": guide_id,
                "model_id": model_id,
                "guide_type": guide_type,
                "message": f"Guide '{guide_id}' created successfully",
            },
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
281
+
282
+
283
@mcp.tool(name="run_svi")
def run_svi(
    svi_id: str,
    model_id: str,
    guide_id: str,
    num_steps: int = 1000,
    learning_rate: float = 0.01
) -> dict:
    """
    Run Stochastic Variational Inference.

    Parameters:
        svi_id (str): Unique identifier for this SVI instance.
        model_id (str): ID of the model.
        guide_id (str): ID of the guide.
        num_steps (int): Number of optimization steps; must be >= 1.
        learning_rate (float): Learning rate for optimization.

    Returns:
        dict: Training results including a subsampled loss history
        (~20 points) and the final loss.
    """
    try:
        if model_id not in _models:
            return {"success": False, "result": None, "error": f"Model '{model_id}' not found"}
        if guide_id not in _guides:
            return {"success": False, "result": None, "error": f"Guide '{guide_id}' not found"}
        # Guard explicitly: with num_steps < 1 the loop never runs and
        # losses[-1] below would raise a confusing IndexError.
        if num_steps < 1:
            return {"success": False, "result": None, "error": "num_steps must be >= 1"}

        model = _models[model_id]
        guide = _guides[guide_id]

        # Start from a fresh parameter store so repeated runs don't leak state.
        pyro.clear_param_store()

        optimizer = Adam({"lr": learning_rate})
        svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

        # NOTE(review): training always uses synthetic continuous data regardless
        # of model type; discrete-observation models (e.g. 'coin_flip') will fail
        # on it and surface via the except branch — confirm this is intended.
        data = torch.randn(100)
        losses = []
        for step in range(num_steps):
            loss = svi.step(data)
            losses.append(loss)
            if step % 100 == 0:
                print(f"Step {step}, Loss: {loss}")

        _svi_instances[svi_id] = svi

        stride = max(1, num_steps // 20)  # return ~20 history points
        return {
            "success": True,
            "result": {
                "svi_id": svi_id,
                "model_id": model_id,
                "guide_id": guide_id,
                "num_steps": num_steps,
                "final_loss": losses[-1],
                "loss_history": losses[::stride],
                "message": "SVI completed successfully",
            },
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
346
+
347
 
348
@mcp.tool(name="list_models")
def list_models() -> dict:
    """
    List all stored models.

    Returns:
        dict: List of model IDs.
    """
    try:
        registries = {
            "models": _models,
            "guides": _guides,
            "svi_instances": _svi_instances,
            "mcmc_instances": _mcmc_instances,
        }
        return {
            "success": True,
            "result": {name: list(reg.keys()) for name, reg in registries.items()},
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
369
+
370
+
371
@mcp.tool(name="delete_model")
def delete_model(model_id: str) -> dict:
    """
    Delete a stored model.

    Parameters:
        model_id (str): ID of the model to delete.

    Returns:
        dict: Confirmation of deletion.
    """
    try:
        if model_id not in _models:
            return {"success": False, "result": None, "error": f"Model '{model_id}' not found"}

        _models.pop(model_id)

        return {
            "success": True,
            "result": {"message": f"Model '{model_id}' deleted successfully"},
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
395
+
396
+
397
@mcp.tool(name="get_distribution_info")
def get_distribution_info(distribution_name: str) -> dict:
    """
    Get information about a specific distribution.

    Parameters:
        distribution_name (str): Name of the distribution (e.g., 'Normal', 'Bernoulli').

    Returns:
        dict: Information about the distribution including parameters.
    """
    try:
        # Static, hand-maintained metadata; not introspected from pyro.
        dist_info = {
            "Normal": {
                "parameters": ["loc (mean)", "scale (std dev)"],
                "support": "real numbers",
                "description": "Gaussian/Normal distribution",
            },
            "Bernoulli": {
                "parameters": ["probs (probability of 1)"],
                "support": "{0, 1}",
                "description": "Binary distribution",
            },
            "Beta": {
                "parameters": ["concentration1 (alpha)", "concentration0 (beta)"],
                "support": "[0, 1]",
                "description": "Beta distribution for probabilities",
            },
            "Categorical": {
                "parameters": ["probs (probability vector)"],
                "support": "{0, 1, ..., K-1}",
                "description": "Categorical distribution over K categories",
            },
            "Dirichlet": {
                "parameters": ["concentration (alpha vector)"],
                "support": "probability simplex",
                "description": "Dirichlet distribution for probability vectors",
            },
            "DiscreteHMM": {
                "parameters": ["initial_logits", "transition_logits", "observation_dist"],
                "support": "sequences of discrete states",
                "description": "Discrete Hidden Markov Model",
            },
            "GaussianHMM": {
                "parameters": ["initial_dist", "transition_matrix", "observation_matrix"],
                "support": "sequences of continuous observations",
                "description": "Gaussian Hidden Markov Model",
            },
        }

        info = dist_info.get(distribution_name)
        if info is None:
            return {
                "success": False,
                "result": None,
                "error": f"Distribution '{distribution_name}' not found in info database",
            }

        return {
            "success": True,
            "result": {"distribution": distribution_name, **info},
            "error": None,
        }
    except Exception as e:
        return {"success": False, "result": None, "error": str(e)}
464
 
465
 
466
def create_app() -> FastMCP:
    """
    Create and return the FastMCP application instance.

    Returns:
        The FastMCP application instance.
    """
    # Returns the module-level singleton; tools are registered at import time.
    return mcp