| using System.Collections.Generic; |
| using Unity.Barracuda; |
| using UnityEngine.Profiling; |
| using Unity.MLAgents.Actuators; |
| using Unity.MLAgents.Policies; |
| using Unity.MLAgents.Sensors; |
|
|
| namespace Unity.MLAgents.Inference |
| { |
| /// <summary> |
| /// Couples an agent's <see cref="AgentInfo"/> with the sensors that were supplied |
| /// together with it, so both can be queued as a single entry. |
| /// </summary> |
| internal struct AgentInfoSensorsPair |
| { |
| // Agent state captured for the queued request. |
| public AgentInfo agentInfo; |
| // Sensors supplied alongside agentInfo for the same request. |
| public List<ISensor> sensors; |
| } |
|
|
| /// <summary> |
| /// Batches decision requests and evaluates them with a Barracuda worker. |
| /// Requests are queued with <see cref="PutObservations"/>, evaluated together by |
| /// <see cref="DecideBatch"/>, and the resulting actions are read with <see cref="GetAction"/>. |
| /// </summary> |
| internal class ModelRunner |
| { |
|     // Requests queued since the last DecideBatch call. |
|     List<AgentInfoSensorsPair> m_Infos = new List<AgentInfoSensorsPair>(); |
|
|     // Last action stored per episode id; PutObservations removes the entry when an agent is done. |
|     Dictionary<int, ActionBuffers> m_LastActionsReceived = new Dictionary<int, ActionBuffers>(); |
|
|     // Episode ids of the pending batch, in the order PutObservations received them. |
|     List<int> m_OrderedAgentsRequestingDecisions = new List<int>(); |
|
|     ITensorAllocator m_TensorAllocator; |
|     TensorGenerator m_TensorGenerator; |
|     TensorApplier m_TensorApplier; |
|
|     NNModel m_Model; |
|     string m_ModelName; |
|     InferenceDevice m_InferenceDevice; |
|
|     // Barracuda worker; stays null when the runner is constructed without a model. |
|     IWorker m_Engine; |
|
|     // Only switched on when BARRACUDA_VERBOSE is defined (and a model was supplied). |
|     bool m_Verbose = false; |
|     bool m_DeterministicInference; |
|     string[] m_OutputNames; |
|     IReadOnlyList<TensorProxy> m_InferenceInputs; |
|     List<TensorProxy> m_InferenceOutputs; |
|
|     // Scratch map of input name -> tensor, rebuilt before every Execute call. |
|     Dictionary<string, Tensor> m_InputsByName; |
|
|     // Shared with both the tensor generator and the tensor applier. |
|     // NOTE(review): presumably recurrent memories keyed by episode id - confirm. |
|     Dictionary<int, List<float>> m_Memories = new Dictionary<int, List<float>>(); |
|
|     // Used by PutObservations in DEBUG builds only. |
|     SensorShapeValidator m_SensorShapeValidator = new SensorShapeValidator(); |
|
|     // Set once DecideBatch has initialized the observation generators. |
|     bool m_ObservationsInitialized; |
|
|     /// <summary> |
|     /// Loads <paramref name="model"/> when one is given, creates the Barracuda worker for the |
|     /// requested device, and builds the tensor generator and applier. |
|     /// </summary> |
|     /// <param name="model">Model to run. When null, no worker is created.</param> |
|     /// <param name="actionSpec">Action description forwarded to the <see cref="TensorApplier"/>.</param> |
|     /// <param name="inferenceDevice">Device used to pick the Barracuda worker type.</param> |
|     /// <param name="seed">Seed forwarded to the tensor generator and the tensor applier.</param> |
|     /// <param name="deterministicInference"> |
|     /// Forwarded to the output-name lookup, the tensor generator and the tensor applier. |
|     /// </param> |
|     /// <exception cref="UnityAgentsException"> |
|     /// Thrown when the model version check reports a failure of type Error. |
|     /// </exception> |
|     public ModelRunner( |
|         NNModel model, |
|         ActionSpec actionSpec, |
|         InferenceDevice inferenceDevice, |
|         int seed = 0, |
|         bool deterministicInference = false) |
|     { |
|         m_Model = model; |
|         m_ModelName = model?.name; |
|         m_InferenceDevice = inferenceDevice; |
|         m_DeterministicInference = deterministicInference; |
|         m_TensorAllocator = new TensorCachingAllocator(); |
|
|         // Both stay null when no model is supplied. |
|         Model barracudaModel = null; |
|         m_Engine = null; |
|
|         if (model != null) |
|         { |
| #if BARRACUDA_VERBOSE |
|             m_Verbose = true; |
| #endif |
|             D.logEnabled = m_Verbose; |
|
|             barracudaModel = ModelLoader.Load(model); |
|
|             // Only failures of type Error abort construction; other check types are ignored here. |
|             var failedCheck = BarracudaModelParamLoader.CheckModelVersion(barracudaModel); |
|             if (failedCheck != null && |
|                 failedCheck.CheckType == BarracudaModelParamLoader.FailedCheck.CheckTypeEnum.Error) |
|             { |
|                 throw new UnityAgentsException(failedCheck.Message); |
|             } |
|
|             m_Engine = WorkerFactory.CreateWorker(GetWorkerType(inferenceDevice), barracudaModel, m_Verbose); |
|         } |
|
|         // NOTE(review): barracudaModel is null here when no model was supplied. These calls only |
|         // succeed if GetInputTensors/GetOutputNames are null-tolerant extension methods - confirm. |
|         m_InferenceInputs = barracudaModel.GetInputTensors(); |
|         m_OutputNames = barracudaModel.GetOutputNames(m_DeterministicInference); |
|
|         m_TensorGenerator = new TensorGenerator( |
|             seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference); |
|         m_TensorApplier = new TensorApplier( |
|             actionSpec, seed, m_TensorAllocator, m_Memories, barracudaModel, m_DeterministicInference); |
|         m_InputsByName = new Dictionary<string, Tensor>(); |
|         m_InferenceOutputs = new List<TensorProxy>(); |
|     } |
|
|     /// <summary> |
|     /// Maps an <see cref="Policies.InferenceDevice"/> value to the Barracuda worker type. |
|     /// </summary> |
|     static WorkerFactory.Type GetWorkerType(InferenceDevice inferenceDevice) |
|     { |
|         switch (inferenceDevice) |
|         { |
|             case InferenceDevice.CPU: |
|                 return WorkerFactory.Type.CSharp; |
|             case InferenceDevice.GPU: |
|                 return WorkerFactory.Type.ComputePrecompiled; |
|             default: |
|                 // Burst, Default and any other value all select the Burst worker. |
|                 return WorkerFactory.Type.CSharpBurst; |
|         } |
|     } |
|
|     /// <summary>Inference device this runner was created with.</summary> |
|     public InferenceDevice InferenceDevice => m_InferenceDevice; |
|
|     /// <summary>Model this runner was created with; may be null.</summary> |
|     public NNModel Model => m_Model; |
|
|     /// <summary> |
|     /// Rebuilds the name-to-tensor map handed to the worker from the given input proxies. |
|     /// </summary> |
|     void PrepareBarracudaInputs(IReadOnlyList<TensorProxy> infInputs) |
|     { |
|         m_InputsByName.Clear(); |
|         // Indexed loop kept on purpose: no enumerator is created for the interface-typed list. |
|         for (var index = 0; index < infInputs.Count; index++) |
|         { |
|             var proxy = infInputs[index]; |
|             m_InputsByName[proxy.name] = proxy.data; |
|         } |
|     } |
|
|     /// <summary> |
|     /// Disposes the worker, if one was created, and resets the tensor allocator. |
|     /// </summary> |
|     public void Dispose() |
|     { |
|         m_Engine?.Dispose(); |
|         // NOTE(review): the false argument presumably means "do not keep cached memory" - confirm. |
|         m_TensorAllocator?.Reset(false); |
|     } |
|
|     /// <summary> |
|     /// Replaces the cached output proxies with the worker's current outputs for the given names. |
|     /// </summary> |
|     void FetchBarracudaOutputs(string[] names) |
|     { |
|         m_InferenceOutputs.Clear(); |
|         foreach (var outputName in names) |
|         { |
|             var tensor = m_Engine.PeekOutput(outputName); |
|             m_InferenceOutputs.Add(TensorUtils.TensorProxyFromBarracuda(tensor, outputName)); |
|         } |
|     } |
|
|     /// <summary> |
|     /// Queues one decision request for the next <see cref="DecideBatch"/> call. |
|     /// </summary> |
|     /// <param name="info">Agent state; its episodeId keys the stored action.</param> |
|     /// <param name="sensors">Sensors queued with the request (validated in DEBUG builds only).</param> |
|     public void PutObservations(AgentInfo info, List<ISensor> sensors) |
|     { |
| #if DEBUG |
|         m_SensorShapeValidator.ValidateSensors(sensors); |
| #endif |
|         var episodeId = info.episodeId; |
|
|         m_Infos.Add(new AgentInfoSensorsPair |
|         { |
|             agentInfo = info, |
|             sensors = sensors |
|         }); |
|
|         // Remember the request order so outputs can be matched back to agents. |
|         m_OrderedAgentsRequestingDecisions.Add(episodeId); |
|
|         if (info.done) |
|         { |
|             // A finished agent keeps no stored action. Removing a missing key is a no-op. |
|             m_LastActionsReceived.Remove(episodeId); |
|         } |
|         else if (!m_LastActionsReceived.ContainsKey(episodeId)) |
|         { |
|             // First request for this episode: start from an empty action. |
|             m_LastActionsReceived[episodeId] = ActionBuffers.Empty; |
|         } |
|     } |
|
|     /// <summary> |
|     /// Runs the model once for all queued requests and stores the resulting actions. |
|     /// Returns immediately when nothing is queued. Both queues are cleared afterwards. |
|     /// </summary> |
|     public void DecideBatch() |
|     { |
|         var batchSize = m_Infos.Count; |
|         if (batchSize == 0) |
|         { |
|             return; |
|         } |
|
|         if (!m_ObservationsInitialized) |
|         { |
|             // Initialized lazily from the first queued request; the batch is known to be non-empty. |
|             // NOTE(review): assumes every agent shares the first agent's sensor layout - confirm. |
|             m_TensorGenerator.InitializeObservations(m_Infos[0].sensors, m_TensorAllocator); |
|             m_ObservationsInitialized = true; |
|         } |
|
|         Profiler.BeginSample("ModelRunner.DecideAction"); |
|         Profiler.BeginSample(m_ModelName); |
|
|         // Fill the input tensors from the queued requests. |
|         Profiler.BeginSample("GenerateTensors"); |
|         m_TensorGenerator.GenerateTensors(m_InferenceInputs, batchSize, m_Infos); |
|         Profiler.EndSample(); |
|
|         Profiler.BeginSample("PrepareBarracudaInputs"); |
|         PrepareBarracudaInputs(m_InferenceInputs); |
|         Profiler.EndSample(); |
|
|         // NOTE(review): m_Engine is null when the runner was built without a model, so this |
|         // call would throw - confirm callers never queue requests in that state. |
|         Profiler.BeginSample("ExecuteGraph"); |
|         m_Engine.Execute(m_InputsByName); |
|         Profiler.EndSample(); |
|
|         Profiler.BeginSample("FetchBarracudaOutputs"); |
|         FetchBarracudaOutputs(m_OutputNames); |
|         Profiler.EndSample(); |
|
|         // Write the outputs into m_LastActionsReceived, following the request order. |
|         Profiler.BeginSample("ApplyTensors"); |
|         m_TensorApplier.ApplyTensors(m_InferenceOutputs, m_OrderedAgentsRequestingDecisions, m_LastActionsReceived); |
|         Profiler.EndSample(); |
|
|         Profiler.EndSample(); // m_ModelName |
|         Profiler.EndSample(); // ModelRunner.DecideAction |
|
|         m_Infos.Clear(); |
|         m_OrderedAgentsRequestingDecisions.Clear(); |
|     } |
|
|     /// <summary> |
|     /// Returns true when this runner was created with the given model and inference device. |
|     /// </summary> |
|     public bool HasModel(NNModel other, InferenceDevice otherInferenceDevice) => |
|         m_Model == other && m_InferenceDevice == otherInferenceDevice; |
|
|     /// <summary> |
|     /// Returns the action stored for <paramref name="agentId"/>, or |
|     /// <see cref="ActionBuffers.Empty"/> when none is stored. |
|     /// </summary> |
|     public ActionBuffers GetAction(int agentId) |
|     { |
|         // Single lookup instead of ContainsKey followed by the indexer. |
|         ActionBuffers lastAction; |
|         return m_LastActionsReceived.TryGetValue(agentId, out lastAction) |
|             ? lastAction |
|             : ActionBuffers.Empty; |
|     } |
| } |
| } |
|
|