| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| import sys |
| from unittest.mock import MagicMock, patch |
|
|
| import pytest |
|
|
| from verl.utils.rollout_trace import RolloutTraceConfig, rollout_trace_attr, rollout_trace_op |
|
|
|
|
@pytest.fixture(autouse=True)
def reset_rollout_trace_config_singleton():
    """Reset the RolloutTraceConfig singleton around each test.

    Resetting before the test gives every test in this module a clean
    configuration. Resetting again on teardown means a configuration
    initialised by the last test here cannot leak into tests in other
    modules that do not use this fixture.
    """
    RolloutTraceConfig.reset()
    yield
    RolloutTraceConfig.reset()
|
|
|
|
@pytest.fixture
def mock_weave_client():
    """Replace the ``weave`` module with mocks and yield the fake client.

    ``weave.init`` is wired to return the fake client. The client's
    ``create_call`` returns one fixed mock call object, so tests can compare
    what is later passed to ``finish_call`` against it. The patch is undone
    when the fixture is torn down.
    """
    fake_call = MagicMock()
    fake_client = MagicMock()
    fake_client.create_call.return_value = fake_call

    fake_weave = MagicMock()
    fake_weave.init.return_value = fake_client
    # Make call_context() yield a mock too, in case the tracing code looks up
    # the current call through weave.trace.context (assumption — confirm).
    fake_weave.trace.context.call_context.return_value = MagicMock()

    patched_modules = {
        "weave": fake_weave,
        "weave.trace.context": fake_weave.trace.context,
    }
    with patch.dict(sys.modules, patched_modules):
        yield fake_client
|
|
|
|
class TracedClass:
    """Test helper whose async methods are wrapped with ``rollout_trace_op``.

    ``upper_method`` is intentionally left undecorated. It only drives the
    decorated methods, including one traced method that calls another.
    """

    @rollout_trace_op
    async def my_method(self, a, b="default"):
        """Return a string echoing both arguments."""
        message = f"result: {a}, {b}"
        return message

    @rollout_trace_op
    async def middle_method(self, a, b="default"):
        """Call ``my_method`` from inside this traced method, then echo the arguments."""
        # A traced op invoked from within another traced op.
        await self.my_method("test_a1", b="test_b1")
        message = f"result: {a}, {b}"
        return message

    @rollout_trace_op
    async def my_method_with_exception(self):
        """Always raise, so error handling in the decorator can be checked."""
        raise ValueError("Test Exception")

    async def upper_method(self):
        """Run the traced methods in sequence and report success."""
        await self.my_method("test_a0", b="test_b0")
        await self.middle_method("test_a2", b="test_b2")
        return True
|
|
|
|
class UntracedClass:
    """Decorated helper exercised without any tracing backend configured."""

    @rollout_trace_op
    async def my_method(self, x):
        """Return ``x`` multiplied by two."""
        doubled = x * 2
        return doubled
|
|
|
|
async def test_rollout_trace_on_untraced_class():
    """The decorator must be transparent when no tracing backend is configured."""
    doubled = await UntracedClass().my_method(10)
    assert doubled == 20
|
|
|
|
async def test_rollout_trace_with_tracer(mock_weave_client):
    """A traced call must open exactly one weave call and close that same call."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    assert RolloutTraceConfig.get_client() is mock_weave_client
    traced = TracedClass()

    output = await traced.my_method("test_a", b="test_b")
    assert output == "result: test_a, test_b"

    # create_call: invoked once, named after the qualified method name, with the
    # call's arguments (excluding ``self``) as inputs.
    create_call = mock_weave_client.create_call
    create_call.assert_called_once()
    created_with = create_call.call_args.kwargs
    assert created_with["op"] == "TracedClass.my_method"
    assert created_with["inputs"] == {"a": "test_a", "b": "test_b"}

    # finish_call: closes the call object create_call returned, with the output.
    mock_weave_client.finish_call.assert_called_once_with(create_call.return_value, output=output)
|
|
|
|
async def test_rollout_trace_with_exception(mock_weave_client):
    """An exception from a traced op is re-raised and handed to ``finish_call``."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")
    traced = TracedClass()

    with pytest.raises(ValueError, match="Test Exception"):
        await traced.my_method_with_exception()

    mock_weave_client.create_call.assert_called_once()
    mock_weave_client.finish_call.assert_called_once()

    finish_args = mock_weave_client.finish_call.call_args
    # The created call must be finished with the raised error attached.
    assert finish_args.args[0] == mock_weave_client.create_call.return_value
    assert "exception" in finish_args.kwargs
    assert isinstance(finish_args.kwargs["exception"], ValueError)
|
|
|
|
async def test_rollout_trace_with_dummy_backend(mock_weave_client):
    """With the 'dummy' backend the method still runs, but weave is never called."""
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="dummy")
    instance = TracedClass()

    result = await instance.my_method("test_a")

    # Disabling tracing must not swallow or alter the wrapped method's result.
    assert result == "result: test_a, default"
    # Neither side of a weave call may be touched.
    mock_weave_client.create_call.assert_not_called()
    mock_weave_client.finish_call.assert_not_called()
|
|
|
|
@pytest.mark.skipif(
    os.environ.get("RUN_WEAVE_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping weave integration test. Set RUN_WEAVE_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_weave_backend():
    """Integration test with a real weave backend.

    Runs only when RUN_WEAVE_INTEGRATION_TESTS=true. It emits a nested trace
    and a failing traced call for manual inspection in the weave project.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="weave")

    instance = TracedClass()

    # Check the traced run completes and returns its result, matching the
    # MLflow integration test.
    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3):
        assert await instance.upper_method()

    with pytest.raises(ValueError, match="Test Exception"):
        await instance.my_method_with_exception()

    print("\nWeave integration test ran successfully. Check your weave project for the trace.")
|
|
|
|
@pytest.mark.skipif(
    os.environ.get("RUN_MLFLOW_INTEGRATION_TESTS", "false").lower() != "true",
    reason="Skipping mlflow integration test. Set RUN_MLFLOW_INTEGRATION_TESTS=true to run.",
)
async def test_rollout_trace_with_real_mlflow_backend():
    """Integration test with a real mlflow backend.

    Runs only when RUN_MLFLOW_INTEGRATION_TESTS=true. It emits a nested trace
    under the "agent_run" name for manual inspection in MLflow.
    """
    RolloutTraceConfig.init(project_name="my-project", experiment_name="my-experiment", backend="mlflow")

    instance = TracedClass()

    with rollout_trace_attr(step=1, sample_index=2, rollout_n=3, name="agent_run"):
        assert await instance.upper_method()

    # Previously copy-pasted from the weave test; this run targets MLflow.
    print("\nMLflow integration test ran successfully. Check your MLflow experiment for the trace.")
|
|