| #if UNITY_EDITOR || UNITY_STANDALONE |
| #define MLA_SUPPORTED_TRAINING_PLATFORM |
| #endif |
|
|
| #if MLA_SUPPORTED_TRAINING_PLATFORM |
| using Grpc.Core; |
| #if UNITY_EDITOR |
| using UnityEditor; |
| #endif |
| using System; |
| using System.Collections.Generic; |
| using System.Linq; |
| using UnityEngine; |
| using Unity.MLAgents.Actuators; |
| using Unity.MLAgents.CommunicatorObjects; |
| using Unity.MLAgents.Sensors; |
| using Unity.MLAgents.SideChannels; |
| using Google.Protobuf; |
|
|
| using Unity.MLAgents.Analytics; |
|
|
| namespace Unity.MLAgents |
| { |
| |
| public class RpcCommunicator : ICommunicator |
| { |
/// <summary>
/// Raised from <c>NotifyQuitAndShutDownChannel</c>, i.e. when a quit command is
/// received or when the connection is being torn down after a failure.
/// </summary>
public event QuitCommandHandler QuitCommandReceived;

/// <summary>
/// Raised from <c>SendCommandEvent</c> when the received command is <c>CommandProto.Reset</c>.
/// </summary>
public event ResetCommandHandler ResetCommandReceived;

/// True while the connection is considered usable; Exchange and Dispose do nothing when false.
bool m_IsOpen;

/// Behavior names registered through SubscribeBrain.
readonly List<string> m_BehaviorNames = new List<string>();

/// Set by PutObservations and consumed (reset) by DecideBatch.
bool m_NeedCommunicateThisStep;

/// Writer handed to each sensor when its observation proto is generated.
readonly ObservationWriter m_ObservationWriter = new ObservationWriter();

/// Per-behavior sensor validators; only touched in DEBUG builds (see PutObservations).
readonly Dictionary<string, SensorShapeValidator> m_SensorShapeValidators = new Dictionary<string, SensorShapeValidator>();

/// Per-behavior episode ids of the not-done agents, in the order they called PutObservations.
/// Received actions are matched to agents by position in this list.
readonly Dictionary<string, List<int>> m_OrderedAgentsRequestingDecisions = new Dictionary<string, List<int>>();

/// Accumulates the agent infos that will be sent with the next exchange.
readonly UnityRLOutputProto m_CurrentUnityRlOutput =
    new UnityRLOutputProto();

/// Behavior name -> episode id -> most recent action stored for that agent.
readonly Dictionary<string, Dictionary<int, ActionBuffers>> m_LastActionsReceived =
    new Dictionary<string, Dictionary<int, ActionBuffers>>();

/// Behavior names whose brain parameters have already been included in an exchange.
readonly HashSet<string> m_SentBrainKeys = new HashSet<string>();

/// Action specs cached by CacheActionSpec that have not been included in an exchange yet.
readonly Dictionary<string, ActionSpec> m_UnsentBrainKeys = new Dictionary<string, ActionSpec>();

/// gRPC client and channel; both are created by the private Initialize overload.
UnityToExternalProto.UnityToExternalProtoClient m_Client;
Channel m_Channel;
|
|
| |
| |
| |
/// <summary>
/// Creates an unconnected communicator. The constructor is protected, so outside
/// code obtains instances through <see cref="Create"/> or through a subclass.
/// </summary>
protected RpcCommunicator() { }
|
|
/// <summary>
/// Factory for <see cref="RpcCommunicator"/>.
/// </summary>
/// <returns>
/// A new communicator when MLA_SUPPORTED_TRAINING_PLATFORM is defined, otherwise null.
/// </returns>
public static RpcCommunicator Create()
{
#if !MLA_SUPPORTED_TRAINING_PLATFORM
    return null;
#else
    return new RpcCommunicator();
#endif
}
|
|
| #region Initialization |
|
|
/// <summary>
/// Decides whether the Unity and Python communication protocol versions can talk to each other.
/// </summary>
/// <remarks>
/// Rules implemented here:
/// - the major versions must always match;
/// - when the Unity major version is 0, the minor versions must match as well;
/// - any remaining difference (minor for major >= 1, build, revision) is accepted.
/// </remarks>
/// <param name="unityCommunicationVersion">Version string reported by the Unity side.</param>
/// <param name="pythonApiVersion">Version string reported by the Python side.</param>
/// <returns>
/// True when the versions are compatible; false when they are not, or when either
/// string is null, empty or not a valid <see cref="Version"/>.
/// </returns>
internal static bool CheckCommunicationVersionsAreCompatible(
    string unityCommunicationVersion,
    string pythonApiVersion
)
{
    // A version that cannot be parsed can never be declared compatible, so report a
    // mismatch instead of letting the Version constructor throw.
    if (!Version.TryParse(unityCommunicationVersion, out var unityVersion) ||
        !Version.TryParse(pythonApiVersion, out var pythonVersion))
    {
        return false;
    }

    if (unityVersion.Major != pythonVersion.Major)
    {
        return false;
    }

    // For 0.x versions the minor number must match too.
    if (unityVersion.Major == 0 && unityVersion.Minor != pythonVersion.Minor)
    {
        return false;
    }

    return true;
}
|
|
| |
| |
| |
| |
| |
| |
| |
/// <summary>
/// Connects to the external process, sends the Unity-side initialization parameters and
/// reads back the parameters supplied by the other side.
/// </summary>
/// <param name="initParameters">Unity-side parameters (name, versions, capabilities, port).</param>
/// <param name="initParametersOut">
/// Parameters received from the external process on success; a default
/// <c>UnityRLInitParameters</c> on failure.
/// </param>
/// <returns>True when the handshake completed and the first input was processed; false otherwise.</returns>
public bool Initialize(CommunicatorInitParameters initParameters, out UnityRLInitParameters initParametersOut)
{
#if MLA_SUPPORTED_TRAINING_PLATFORM
    var academyParameters = new UnityRLInitializationOutputProto
    {
        Name = initParameters.name,
        PackageVersion = initParameters.unityPackageVersion,
        CommunicationVersion = initParameters.unityCommunicationVersion,
        Capabilities = initParameters.CSharpCapabilities.ToProto()
    };

    UnityInputProto input;
    UnityInputProto initializationInput;
    try
    {
        initializationInput = Initialize(
            initParameters.port,
            new UnityOutputProto
            {
                RlInitializationOutput = academyParameters
            },
            out input
        );
    }
    catch (Exception ex)
    {
        if (ex is RpcException rpcException)
        {
            switch (rpcException.Status.StatusCode)
            {
                case StatusCode.Unavailable:
                case StatusCode.DeadlineExceeded:
                    // Deliberately not logged.
                    // NOTE(review): presumably the "no trainer is listening" case - confirm.
                    break;
                default:
                    Debug.Log($"Unexpected gRPC exception when trying to initialize communication: {rpcException}");
                    break;
            }
        }
        else
        {
            Debug.Log($"Unexpected exception when trying to initialize communication: {ex}");
        }

        // The connection never became usable: mark it closed so Exchange and Dispose
        // do not try to talk over it afterwards.
        m_IsOpen = false;
        initParametersOut = new UnityRLInitParameters();
        NotifyQuitAndShutDownChannel();
        return false;
    }

    // The first reply may carry no initialization payload (for example when the handshake
    // was rejected). Check before reading the version fields instead of dereferencing null.
    var rlInitializationInput = initializationInput?.RlInitializationInput;
    if (rlInitializationInput == null)
    {
        Debug.LogWarning(
            "Unknown communication error between Python. No initialization data was received."
        );
        initParametersOut = new UnityRLInitParameters();
        return false;
    }

    var pythonPackageVersion = rlInitializationInput.PackageVersion;
    var pythonCommunicationVersion = rlInitializationInput.CommunicationVersion;
    TrainingAnalytics.SetTrainerInformation(pythonPackageVersion, pythonCommunicationVersion);

    var communicationIsCompatible = CheckCommunicationVersionsAreCompatible(
        initParameters.unityCommunicationVersion,
        pythonCommunicationVersion
    );

    // Initialization data arrived but the follow-up input did not: report the failure,
    // pointing at a version mismatch when there is one.
    if (input?.RlInput == null)
    {
        if (!communicationIsCompatible)
        {
            Debug.LogWarningFormat(
                "Communication protocol between python ({0}) and Unity ({1}) have different " +
                "versions which make them incompatible. Python library version: {2}.",
                pythonCommunicationVersion, initParameters.unityCommunicationVersion,
                pythonPackageVersion
            );
        }
        else
        {
            Debug.LogWarningFormat(
                "Unknown communication error between Python. Python communication protocol: {0}, " +
                "Python library version: {1}.",
                pythonCommunicationVersion,
                pythonPackageVersion
            );
        }

        initParametersOut = new UnityRLInitParameters();
        return false;
    }

    UpdateEnvironmentWithInput(input.RlInput);
    initParametersOut = rlInitializationInput.ToUnityRLInitParameters();

    // Make sure the quit notification and channel shutdown run when the application quits.
    Application.quitting += NotifyQuitAndShutDownChannel;
    return true;
#else
    initParametersOut = new UnityRLInitParameters();
    return false;
#endif
}
|
|
| |
| |
| |
| |
| |
/// <summary>
/// Registers a behavior with the communicator. A key that is already registered is ignored.
/// </summary>
/// <param name="brainKey">Behavior name used as the key for agent infos and actions.</param>
/// <param name="actionSpec">Action spec cached for sending with a later exchange.</param>
public void SubscribeBrain(string brainKey, ActionSpec actionSpec)
{
    var alreadyRegistered = m_BehaviorNames.Contains(brainKey);
    if (!alreadyRegistered)
    {
        m_BehaviorNames.Add(brainKey);

        // Give the behavior an (initially empty) agent-info list in the outgoing message.
        var emptyAgentInfos = new UnityRLOutputProto.Types.ListAgentInfoProto();
        m_CurrentUnityRlOutput.AgentInfos.Add(brainKey, emptyAgentInfos);

        CacheActionSpec(brainKey, actionSpec);
    }
}
|
|
/// <summary>
/// Applies a received input: forwards its side-channel bytes to the
/// <c>SideChannelManager</c>, then dispatches its command.
/// </summary>
/// <param name="rlInput">Input received from the external process; must not be null.</param>
void UpdateEnvironmentWithInput(UnityRLInputProto rlInput)
{
    var sideChannelBytes = rlInput.SideChannel.ToArray();
    SideChannelManager.ProcessSideChannelData(sideChannelBytes);

    var receivedCommand = rlInput.Command;
    SendCommandEvent(receivedCommand);
}
|
|
/// <summary>
/// Opens an insecure gRPC channel to <c>localhost:port</c> and performs two exchanges:
/// the first sends <paramref name="unityOutput"/>, the second sends a message without payload.
/// </summary>
/// <param name="port">Port on localhost to connect to.</param>
/// <param name="unityOutput">Payload of the first exchange.</param>
/// <param name="unityInput">Receives the <c>UnityInput</c> of the second reply.</param>
/// <returns>The <c>UnityInput</c> of the first reply.</returns>
UnityInputProto Initialize(int port, UnityOutputProto unityOutput, out UnityInputProto unityInput)
{
    m_IsOpen = true;

    var channel = new Channel($"localhost:{port}", ChannelCredentials.Insecure);
    m_Channel = channel;
    m_Client = new UnityToExternalProto.UnityToExternalProtoClient(channel);

    var initializationReply = m_Client.Exchange(WrapMessage(unityOutput, 200));
    var firstInputReply = m_Client.Exchange(WrapMessage(null, 200));
    unityInput = firstInputReply.UnityInput;

#if UNITY_EDITOR
    EditorApplication.playModeStateChanged += HandleOnPlayModeChanged;
#endif

    // Both replies must carry status 200; otherwise the connection is closed right away.
    var bothRepliesOk = initializationReply.Header.Status == 200 && firstInputReply.Header.Status == 200;
    if (!bothRepliesOk)
    {
        m_IsOpen = false;
        NotifyQuitAndShutDownChannel();
    }

    return initializationReply.UnityInput;
}
|
|
/// <summary>
/// Raises <see cref="QuitCommandReceived"/> and then shuts the gRPC channel down,
/// blocking until the shutdown task finishes.
/// </summary>
void NotifyQuitAndShutDownChannel()
{
    QuitCommandReceived?.Invoke();

    // The channel is missing when its creation failed; nothing to shut down then.
    // Checking explicitly avoids relying on a swallowed NullReferenceException.
    if (m_Channel == null)
    {
        return;
    }

    try
    {
        m_Channel.ShutdownAsync().Wait();
    }
    catch (Exception)
    {
        // Best effort: a failing shutdown (for example of a channel that was already
        // shut down) must not stop the caller from finishing its own teardown.
    }
}
|
|
| #endregion |
|
|
| #region Destruction |
|
|
| |
| |
| |
/// <summary>
/// Closes the communicator: sends a final message with header status 400 and no payload,
/// then marks the connection as closed. Does nothing when the connection is not open.
/// </summary>
/// <remarks>
/// NOTE(review): status 400 presumably tells the external process that Unity is closing - confirm.
/// </remarks>
public void Dispose()
{
#if UNITY_EDITOR
    // The editor event is static, so a handler left attached keeps this instance alive.
    // Removing a handler that was never attached is a no-op.
    EditorApplication.playModeStateChanged -= HandleOnPlayModeChanged;
#endif

    if (!m_IsOpen)
    {
        return;
    }

    try
    {
        m_Client.Exchange(WrapMessage(null, 400));
    }
    catch (Exception)
    {
        // Best effort: the other side may already be gone, and Dispose must not throw.
    }
    finally
    {
        // Closed either way; a failed goodbye must not leave the communicator "open"
        // so that later calls keep retrying over a dead connection.
        m_IsOpen = false;
    }
}
|
|
| #endregion |
|
|
| #region Sending Events |
|
|
/// <summary>
/// Reacts to a command received from the external process.
/// Quit: notifies listeners and shuts the channel down.
/// Reset: forgets all pending decision requests, then raises <see cref="ResetCommandReceived"/>.
/// Any other command is ignored.
/// </summary>
/// <param name="command">Command carried by the received input.</param>
void SendCommandEvent(CommandProto command)
{
    if (command == CommandProto.Quit)
    {
        NotifyQuitAndShutDownChannel();
    }
    else if (command == CommandProto.Reset)
    {
        foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
        {
            pendingAgentIds.Clear();
        }

        ResetCommandReceived?.Invoke();
    }
}
|
|
| #endregion |
|
|
| #region Sending and retrieving data |
|
|
/// <summary>
/// Sends the batched observations if any were put since the last call; otherwise does nothing.
/// </summary>
public void DecideBatch()
{
    if (m_NeedCommunicateThisStep)
    {
        // Reset the flag before sending, exactly once per batch.
        m_NeedCommunicateThisStep = false;
        SendBatchedMessageHelper();
    }
}
|
|
| |
| |
| |
| |
| |
| |
/// <summary>
/// Queues one agent's info and sensor observations for the next exchange and records
/// that the agent is waiting for a decision (unless it is done).
/// </summary>
/// <param name="behaviorName">Behavior the agent belongs to; must have been subscribed.</param>
/// <param name="info">Agent info; <c>episodeId</c> identifies the agent, <c>done</c> marks its last step.</param>
/// <param name="sensors">Sensors whose observations are written into the outgoing message.</param>
public void PutObservations(string behaviorName, AgentInfo info, List<ISensor> sensors)
{
#if DEBUG
    if (!m_SensorShapeValidators.ContainsKey(behaviorName))
    {
        m_SensorShapeValidators[behaviorName] = new SensorShapeValidator();
    }
    m_SensorShapeValidators[behaviorName].ValidateSensors(sensors);
#endif

    using (TimerStack.Instance.Scoped("AgentInfo.ToProto"))
    {
        var infoProto = info.ToAgentInfoProto();

        using (TimerStack.Instance.Scoped("GenerateSensorData"))
        {
            foreach (var sensor in sensors)
            {
                infoProto.Observations.Add(sensor.GetObservationProto(m_ObservationWriter));
            }
        }

        m_CurrentUnityRlOutput.AgentInfos[behaviorName].Value.Add(infoProto);
    }

    m_NeedCommunicateThisStep = true;

    // Agents that are not done wait for an action, in the order they reported.
    if (!m_OrderedAgentsRequestingDecisions.TryGetValue(behaviorName, out var waitingAgentIds))
    {
        waitingAgentIds = new List<int>();
        m_OrderedAgentsRequestingDecisions[behaviorName] = waitingAgentIds;
    }
    if (!info.done)
    {
        waitingAgentIds.Add(info.episodeId);
    }

    if (!m_LastActionsReceived.TryGetValue(behaviorName, out var actionsByAgent))
    {
        actionsByAgent = new Dictionary<int, ActionBuffers>();
        m_LastActionsReceived[behaviorName] = actionsByAgent;
    }

    // A done agent leaves no stored action behind; any other agent starts from an empty one.
    if (info.done)
    {
        actionsByAgent.Remove(info.episodeId);
    }
    else
    {
        actionsByAgent[info.episodeId] = ActionBuffers.Empty;
    }
}
|
|
| |
| |
| |
| |
/// <summary>
/// Sends the accumulated agent infos (plus any not-yet-sent brain parameters and the
/// aggregated side-channel data), then stores the actions received in return.
/// </summary>
/// <remarks>
/// Received actions are matched to agents by position: the i-th action of a behavior goes to
/// the i-th episode id recorded for it by PutObservations. Pending requests are cleared on
/// every path, including when no usable input came back.
/// </remarks>
void SendBatchedMessageHelper()
{
    var message = new UnityOutputProto
    {
        RlOutput = m_CurrentUnityRlOutput,
    };
    var tempUnityRlInitializationOutput = GetTempUnityRlInitializationOutput();
    if (tempUnityRlInitializationOutput != null)
    {
        message.RlInitializationOutput = tempUnityRlInitializationOutput;
    }

    byte[] messageAggregated = SideChannelManager.GetSideChannelMessage();
    message.RlOutput.SideChannel = ByteString.CopyFrom(messageAggregated);

    var input = Exchange(message);
    UpdateSentActionSpec(tempUnityRlInitializationOutput);

    // The agent infos have been handed to the exchange; start the next batch empty.
    foreach (var k in m_CurrentUnityRlOutput.AgentInfos.Keys)
    {
        m_CurrentUnityRlOutput.AgentInfos[k].Value.Clear();
    }

    var rlInput = input?.RlInput;
    if (rlInput?.AgentActions != null)
    {
        UpdateEnvironmentWithInput(rlInput);

        foreach (var entry in rlInput.AgentActions)
        {
            var brainName = entry.Key;

            // Skip behaviors nobody is waiting on, including names this side never saw.
            if (!m_OrderedAgentsRequestingDecisions.TryGetValue(brainName, out var waitingAgentIds) ||
                waitingAgentIds.Count == 0)
            {
                continue;
            }

            if (!entry.Value.Value.Any())
            {
                continue;
            }

            if (!m_LastActionsReceived.TryGetValue(brainName, out var actionsByAgent))
            {
                continue;
            }

            var agentActions = entry.Value.ToAgentActionList();

            // Never index past the shorter list if the reply holds fewer actions than
            // there are waiting agents.
            var pairCount = Math.Min(waitingAgentIds.Count, agentActions.Count);
            for (var i = 0; i < pairCount; i++)
            {
                var agentId = waitingAgentIds[i];
                if (actionsByAgent.ContainsKey(agentId))
                {
                    actionsByAgent[agentId] = agentActions[i];
                }
            }
        }
    }

    // Always forget the pending requests, so that ids from a step that got no answer
    // cannot shift the position-based matching of a later step.
    foreach (var pendingAgentIds in m_OrderedAgentsRequestingDecisions.Values)
    {
        pendingAgentIds.Clear();
    }
}
|
|
/// <summary>
/// Looks up the most recent action stored for an agent.
/// </summary>
/// <param name="behaviorName">Behavior the agent belongs to.</param>
/// <param name="agentId">Episode id under which the agent's observations were put.</param>
/// <returns>The stored action, or <c>ActionBuffers.Empty</c> when none is stored.</returns>
public ActionBuffers GetActions(string behaviorName, int agentId)
{
    if (m_LastActionsReceived.TryGetValue(behaviorName, out var actionsByAgent) &&
        actionsByAgent.TryGetValue(agentId, out var storedAction))
    {
        return storedAction;
    }

    return ActionBuffers.Empty;
}
|
|
| |
| |
| |
| |
| |
/// <summary>
/// Sends one output message with status 200 and returns the input carried by the reply.
/// </summary>
/// <remarks>
/// A reply with a status other than 200, or any exception, closes the communicator:
/// <c>m_IsOpen</c> is cleared and the quit notification / channel shutdown is run.
/// </remarks>
/// <param name="unityOutput">Payload to send.</param>
/// <returns>
/// The reply's <c>UnityInput</c>; null when the communicator is not open or the exchange threw.
/// </returns>
UnityInputProto Exchange(UnityOutputProto unityOutput)
{
    if (!m_IsOpen)
    {
        return null;
    }

    try
    {
        var reply = m_Client.Exchange(WrapMessage(unityOutput, 200));
        if (reply.Header.Status != 200)
        {
            // The other side did not answer with status 200: stop communicating.
            m_IsOpen = false;
            NotifyQuitAndShutDownChannel();
        }

        return reply.UnityInput;
    }
    catch (Exception ex)
    {
        var rpcException = ex as RpcException;
        if (rpcException == null)
        {
            // Not a gRPC failure at all.
            Debug.LogError($"Communication Exception: {ex.Message}. Disconnecting from trainer.");
        }
        else
        {
            var statusCode = rpcException.Status.StatusCode;
            if (statusCode == StatusCode.ResourceExhausted)
            {
                Debug.LogError($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
            }
            else if (statusCode != StatusCode.Unavailable)
            {
                // Unavailable is deliberately not logged; every other code gets an info log.
                Debug.Log($"GRPC Exception: {rpcException.Message}. Disconnecting from trainer.");
            }
        }

        m_IsOpen = false;
        NotifyQuitAndShutDownChannel();
        return null;
    }
}
|
|
| |
| |
| |
| |
| |
| |
/// <summary>
/// Wraps a payload in a <c>UnityMessageProto</c> whose header carries the given status.
/// </summary>
/// <param name="content">Payload to send; may be null for a message without payload.</param>
/// <param name="status">Status written into the message header.</param>
/// <returns>The wrapped message.</returns>
static UnityMessageProto WrapMessage(UnityOutputProto content, int status)
{
    var header = new HeaderProto { Status = status };
    var wrapped = new UnityMessageProto
    {
        Header = header,
        UnityOutput = content
    };
    return wrapped;
}
|
|
/// <summary>
/// Remembers the action spec of a behavior until it has been sent.
/// Behaviors whose parameters were already sent are ignored.
/// </summary>
/// <param name="behaviorName">Behavior the spec belongs to.</param>
/// <param name="actionSpec">Spec to cache; replaces any spec cached earlier for the same name.</param>
void CacheActionSpec(string behaviorName, ActionSpec actionSpec)
{
    var alreadySent = m_SentBrainKeys.Contains(behaviorName);
    if (!alreadySent)
    {
        m_UnsentBrainKeys[behaviorName] = actionSpec;
    }
}
|
|
/// <summary>
/// Builds the brain-parameter message for every unsent behavior that currently has
/// agent infos queued (an agent-info list whose serialized size is greater than zero).
/// </summary>
/// <remarks>
/// NOTE(review): presumably this keeps brain parameters from arriving before the first
/// observation of the behavior - confirm against the receiving side.
/// </remarks>
/// <returns>The message, or null when no unsent behavior qualifies.</returns>
UnityRLInitializationOutputProto GetTempUnityRlInitializationOutput()
{
    UnityRLInitializationOutputProto pendingOutput = null;

    foreach (var unsent in m_UnsentBrainKeys)
    {
        var behaviorName = unsent.Key;
        if (!m_CurrentUnityRlOutput.AgentInfos.ContainsKey(behaviorName))
        {
            continue;
        }

        if (m_CurrentUnityRlOutput.AgentInfos[behaviorName].CalculateSize() <= 0)
        {
            continue;
        }

        // Create the message lazily so that "nothing to send" stays null.
        pendingOutput = pendingOutput ?? new UnityRLInitializationOutputProto();
        pendingOutput.BrainParameters.Add(unsent.Value.ToBrainParametersProto(behaviorName, true));
    }

    return pendingOutput;
}
|
|
/// <summary>
/// Moves every behavior contained in <paramref name="output"/> from the unsent set to the sent set.
/// </summary>
/// <param name="output">Brain parameters that were included in an exchange; null means none.</param>
void UpdateSentActionSpec(UnityRLInitializationOutputProto output)
{
    if (output != null)
    {
        foreach (var sentBrain in output.BrainParameters)
        {
            var brainName = sentBrain.BrainName;
            m_SentBrainKeys.Add(brainName);
            m_UnsentBrainKeys.Remove(brainName);
        }
    }
}
|
|
| #endregion |
|
|
| #if UNITY_EDITOR |
| |
| |
| |
| |
/// <summary>
/// Editor callback: disposes the communicator when the editor is leaving play mode.
/// </summary>
/// <param name="state">Play-mode transition reported by the editor.</param>
void HandleOnPlayModeChanged(PlayModeStateChange state)
{
    if (state != PlayModeStateChange.ExitingPlayMode)
    {
        return;
    }

    Dispose();
}
|
|
| #endif |
| } |
| } |
| #endif // UNITY_EDITOR || UNITY_STANDALONE |
|
|