| using System; |
| using System.Collections.Generic; |
| using System.Linq; |
| using Unity.Barracuda; |
| using FailedCheck = Unity.MLAgents.Inference.BarracudaModelParamLoader.FailedCheck; |
|
|
namespace Unity.MLAgents.Inference
{
    /// <summary>
    /// Barracuda <see cref="Model"/> extension methods used by inference to list a model's
    /// input/output tensor names, compute its action output sizes, and validate that the
    /// constants and output nodes it needs are present.
    /// </summary>
    internal static class BarracudaModelExtensions
    {
        /// <summary>
        /// Get the names of the model's inputs: every declared input plus the input name of
        /// each of the model's memories, sorted with <see cref="StringComparer.InvariantCulture"/>.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <returns>
        /// The sorted input names, or an empty array if <paramref name="model"/> is null.
        /// </returns>
        public static string[] GetInputNames(this Model model)
        {
            var names = new List<string>();

            if (model == null)
            {
                return names.ToArray();
            }

            foreach (var input in model.inputs)
            {
                names.Add(input.name);
            }

            // Each memory also contributes an input name.
            foreach (var memory in model.memories)
            {
                names.Add(memory.input);
            }

            names.Sort(StringComparer.InvariantCulture);
            return names.ToArray();
        }

        /// <summary>
        /// Get the model API version stored in the <see cref="TensorNames.VersionNumber"/> constant.
        /// </summary>
        /// <param name="model">The Barracuda model.</param>
        /// <returns>The first element of the version tensor, cast to <see cref="int"/>.</returns>
        /// <remarks>
        /// No null checks are performed: if the version constant is absent the lookup returns null
        /// and indexing it throws. <see cref="CheckExpectedTensors"/> reports a missing version constant.
        /// </remarks>
        public static int GetVersion(this Model model)
        {
            return (int)model.GetTensorByName(TensorNames.VersionNumber)[0];
        }

        /// <summary>
        /// Build a <see cref="TensorProxy"/> description for each declared input of the model.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <returns>
        /// One floating-point proxy per declared input (memories are not included), with no data
        /// attached and the input's shape converted to <see cref="long"/> values, sorted by name
        /// with the invariant culture. Empty if <paramref name="model"/> is null.
        /// </returns>
        public static IReadOnlyList<TensorProxy> GetInputTensors(this Model model)
        {
            var tensors = new List<TensorProxy>();

            if (model == null)
            {
                return tensors;
            }

            foreach (var input in model.inputs)
            {
                tensors.Add(new TensorProxy
                {
                    name = input.name,
                    valueType = TensorProxy.TensorType.FloatingPoint,
                    data = null,
                    shape = input.shape.Select(i => (long)i).ToArray()
                });
            }

            tensors.Sort((el1, el2) => string.Compare(el1.name, el2.name, StringComparison.InvariantCulture));
            return tensors;
        }

        /// <summary>
        /// Count the declared model inputs whose name starts with
        /// <see cref="TensorNames.VisualObservationPlaceholderPrefix"/>.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <returns>The number of matching inputs, or 0 if <paramref name="model"/> is null.</returns>
        public static int GetNumVisualInputs(this Model model)
        {
            if (model == null)
            {
                return 0;
            }

            // Tensor names are identifiers: compare the prefix ordinally, not with the
            // current culture's linguistic rules.
            return model.inputs.Count(
                input => input.name.StartsWith(TensorNames.VisualObservationPlaceholderPrefix, StringComparison.Ordinal));
        }

        /// <summary>
        /// Get the names of the model's action outputs for the requested inference mode, plus
        /// <see cref="TensorNames.RecurrentOutput"/> when the model's memory size is positive.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <param name="deterministicInference">
        /// Whether to select the deterministic action outputs instead of the stochastic ones.
        /// </param>
        /// <returns>
        /// The output names sorted with the invariant culture, or an empty array if
        /// <paramref name="model"/> is null.
        /// </returns>
        /// <remarks>
        /// Reads the <see cref="TensorNames.MemorySize"/> constant without a null check, so it throws
        /// if that constant is missing; <see cref="CheckExpectedTensors"/> reports that case.
        /// </remarks>
        public static string[] GetOutputNames(this Model model, bool deterministicInference = false)
        {
            var names = new List<string>();

            if (model == null)
            {
                return names.ToArray();
            }

            if (model.HasContinuousOutputs(deterministicInference))
            {
                names.Add(model.ContinuousOutputName(deterministicInference));
            }
            if (model.HasDiscreteOutputs(deterministicInference))
            {
                names.Add(model.DiscreteOutputName(deterministicInference));
            }

            var memory = (int)model.GetTensorByName(TensorNames.MemorySize)[0];
            if (memory > 0)
            {
                names.Add(TensorNames.RecurrentOutput);
            }

            names.Sort(StringComparer.InvariantCulture);
            return names.ToArray();
        }

        /// <summary>
        /// Check whether the model produces continuous actions for the requested inference mode.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <param name="deterministicInference">
        /// Whether to look for the deterministic continuous output instead of the stochastic one.
        /// </param>
        /// <returns>
        /// False if <paramref name="model"/> is null. For models without the new action outputs
        /// (see <see cref="SupportsContinuousAndDiscrete"/>), whether the
        /// <see cref="TensorNames.IsContinuousControlDeprecated"/> constant is positive; the mode
        /// flag is ignored in that case. Otherwise, whether the model has the continuous output
        /// node for the requested mode and a positive <see cref="ContinuousOutputSize"/>.
        /// </returns>
        public static bool HasContinuousOutputs(this Model model, bool deterministicInference = false)
        {
            if (model == null)
            {
                return false;
            }
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0;
            }

            var outputName = deterministicInference
                ? TensorNames.DeterministicContinuousActionOutput
                : TensorNames.ContinuousActionOutput;

            // ContinuousOutputSize treats a missing shape constant as size 0, so this returns
            // false instead of throwing, consistent with HasDiscreteOutputs.
            return model.outputs.Contains(outputName) && model.ContinuousOutputSize() > 0;
        }

        /// <summary>
        /// Get the continuous action output size of the model.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <returns>
        /// 0 if <paramref name="model"/> is null. For models without the new action outputs, the
        /// first element of <see cref="TensorNames.ActionOutputShapeDeprecated"/> when
        /// <see cref="TensorNames.IsContinuousControlDeprecated"/> is positive, otherwise 0.
        /// Otherwise, the first element of <see cref="TensorNames.ContinuousActionOutputShape"/>,
        /// or 0 if that constant is missing.
        /// </returns>
        public static int ContinuousOutputSize(this Model model)
        {
            if (model == null)
            {
                return 0;
            }
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                    (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0] : 0;
            }

            var continuousOutputShape = model.GetTensorByName(TensorNames.ContinuousActionOutputShape);
            return continuousOutputShape == null ? 0 : (int)continuousOutputShape[0];
        }

        /// <summary>
        /// Get the name of the continuous action output tensor.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <param name="deterministicInference">
        /// Whether to return the deterministic output name instead of the stochastic one.
        /// </param>
        /// <returns>
        /// Null if <paramref name="model"/> is null; <see cref="TensorNames.ActionOutputDeprecated"/>
        /// for models without the new action outputs; otherwise the deterministic or stochastic
        /// continuous output name. The presence of the node is not checked.
        /// </returns>
        public static string ContinuousOutputName(this Model model, bool deterministicInference = false)
        {
            if (model == null)
            {
                return null;
            }
            if (!model.SupportsContinuousAndDiscrete())
            {
                return TensorNames.ActionOutputDeprecated;
            }

            return deterministicInference
                ? TensorNames.DeterministicContinuousActionOutput
                : TensorNames.ContinuousActionOutput;
        }

        /// <summary>
        /// Check whether the model produces discrete actions for the requested inference mode.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <param name="deterministicInference">
        /// Whether to look for the deterministic discrete output instead of the stochastic one.
        /// </param>
        /// <returns>
        /// False if <paramref name="model"/> is null. For models without the new action outputs,
        /// whether the <see cref="TensorNames.IsContinuousControlDeprecated"/> constant equals 0;
        /// the mode flag is ignored in that case. Otherwise, whether the model has the discrete
        /// output node for the requested mode and a positive <see cref="DiscreteOutputSize"/>.
        /// </returns>
        public static bool HasDiscreteOutputs(this Model model, bool deterministicInference = false)
        {
            if (model == null)
            {
                return false;
            }
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] == 0;
            }

            var outputName = deterministicInference
                ? TensorNames.DeterministicDiscreteActionOutput
                : TensorNames.DiscreteActionOutput;

            return model.outputs.Contains(outputName) && model.DiscreteOutputSize() > 0;
        }

        /// <summary>
        /// Get the discrete action output size of the model.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <returns>
        /// 0 if <paramref name="model"/> is null. For models without the new action outputs, 0 when
        /// <see cref="TensorNames.IsContinuousControlDeprecated"/> is positive, otherwise the first
        /// element of <see cref="TensorNames.ActionOutputShapeDeprecated"/>. Otherwise, the sum of
        /// every element of <see cref="TensorNames.DiscreteActionOutputShape"/>, or 0 if that
        /// constant is missing.
        /// </returns>
        /// <remarks>
        /// Summing all elements gives the same total whether the shape tensor stores a single
        /// pre-summed value or one value per element (presumably one per discrete branch;
        /// NOTE(review): confirm against the exporter).
        /// </remarks>
        public static int DiscreteOutputSize(this Model model)
        {
            if (model == null)
            {
                return 0;
            }
            if (!model.SupportsContinuousAndDiscrete())
            {
                return (int)model.GetTensorByName(TensorNames.IsContinuousControlDeprecated)[0] > 0 ?
                    0 : (int)model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated)[0];
            }

            var discreteOutputShape = model.GetTensorByName(TensorNames.DiscreteActionOutputShape);
            if (discreteOutputShape == null)
            {
                return 0;
            }

            var result = 0;
            for (var i = 0; i < discreteOutputShape.length; i++)
            {
                result += (int)discreteOutputShape[i];
            }
            return result;
        }

        /// <summary>
        /// Get the name of the discrete action output tensor.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <param name="deterministicInference">
        /// Whether to return the deterministic output name instead of the stochastic one.
        /// </param>
        /// <returns>
        /// Null if <paramref name="model"/> is null; <see cref="TensorNames.ActionOutputDeprecated"/>
        /// for models without the new action outputs; otherwise the deterministic or stochastic
        /// discrete output name. The presence of the node is not checked.
        /// </returns>
        public static string DiscreteOutputName(this Model model, bool deterministicInference = false)
        {
            if (model == null)
            {
                return null;
            }
            if (!model.SupportsContinuousAndDiscrete())
            {
                return TensorNames.ActionOutputDeprecated;
            }

            return deterministicInference
                ? TensorNames.DeterministicDiscreteActionOutput
                : TensorNames.DiscreteActionOutput;
        }

        /// <summary>
        /// Check whether the model uses the separate continuous/discrete action outputs rather than
        /// the deprecated single action output format.
        /// </summary>
        /// <param name="model">The Barracuda model. May be null.</param>
        /// <returns>
        /// True if <paramref name="model"/> is null, or if its outputs contain
        /// <see cref="TensorNames.ContinuousActionOutput"/> or <see cref="TensorNames.DiscreteActionOutput"/>.
        /// Only the stochastic output names are checked.
        /// </returns>
        public static bool SupportsContinuousAndDiscrete(this Model model)
        {
            return model == null ||
                model.outputs.Contains(TensorNames.ContinuousActionOutput) ||
                model.outputs.Contains(TensorNames.DiscreteActionOutput);
        }

        /// <summary>
        /// Check that the model contains the constants and action output nodes used by inference.
        /// Stops at the first problem found and records it as a warning.
        /// </summary>
        /// <param name="model">The Barracuda model. Not checked for null.</param>
        /// <param name="failedModelChecks">Receives one <see cref="FailedCheck"/> warning on failure.</param>
        /// <param name="deterministicInference">Whether deterministic inference was requested.</param>
        /// <returns>True if every check passes; otherwise false after adding one warning.</returns>
        public static bool CheckExpectedTensors(this Model model, List<FailedCheck> failedModelChecks, bool deterministicInference = false)
        {
            // The API version constant must be present.
            var modelApiVersionTensor = model.GetTensorByName(TensorNames.VersionNumber);
            if (modelApiVersionTensor == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.VersionNumber}\" was not found in the model file.")
                );
                return false;
            }

            // The memory size constant must be present.
            var memorySizeTensor = model.GetTensorByName(TensorNames.MemorySize);
            if (memorySizeTensor == null)
            {
                failedModelChecks.Add(
                    FailedCheck.Warning($"Required constant \"{TensorNames.MemorySize}\" was not found in the model file.")
                );
                return false;
            }

            // At least one action output node (of any format or mode) must be present.
            if (!model.outputs.Contains(TensorNames.ActionOutputDeprecated) &&
                !model.outputs.Contains(TensorNames.ContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DiscreteActionOutput) &&
                !model.outputs.Contains(TensorNames.DeterministicContinuousActionOutput) &&
                !model.outputs.Contains(TensorNames.DeterministicDiscreteActionOutput))
            {
                failedModelChecks.Add(
                    FailedCheck.Warning("The model does not contain any Action Output Node.")
                );
                return false;
            }

            if (!model.SupportsContinuousAndDiscrete())
            {
                // Deprecated format: needs the single action output shape and the continuous flag.
                if (model.GetTensorByName(TensorNames.ActionOutputShapeDeprecated) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning("The model does not contain any Action Output Shape Node.")
                    );
                    return false;
                }
                if (model.GetTensorByName(TensorNames.IsContinuousControlDeprecated) == null)
                {
                    failedModelChecks.Add(
                        FailedCheck.Warning($"Required constant \"{TensorNames.IsContinuousControlDeprecated}\" was " +
                            "not found in the model file. " +
                            "This is only required for model that uses a deprecated model format.")
                    );
                    return false;
                }
            }
            else
            {
                // New format: each stochastic output node present needs its shape constant and an
                // output usable in the requested inference mode.
                if (model.outputs.Contains(TensorNames.ContinuousActionOutput))
                {
                    if (model.GetTensorByName(TensorNames.ContinuousActionOutputShape) == null)
                    {
                        failedModelChecks.Add(
                            FailedCheck.Warning("The model uses continuous action but does not contain Continuous Action Output Shape Node.")
                        );
                        return false;
                    }
                    if (!model.HasContinuousOutputs(deterministicInference))
                    {
                        failedModelChecks.Add(
                            FailedCheck.Warning(MissingActionOutputMessage("Continuous", deterministicInference))
                        );
                        return false;
                    }
                }

                if (model.outputs.Contains(TensorNames.DiscreteActionOutput))
                {
                    if (model.GetTensorByName(TensorNames.DiscreteActionOutputShape) == null)
                    {
                        failedModelChecks.Add(
                            FailedCheck.Warning("The model uses discrete action but does not contain Discrete Action Output Shape Node.")
                        );
                        return false;
                    }
                    if (!model.HasDiscreteOutputs(deterministicInference))
                    {
                        failedModelChecks.Add(
                            FailedCheck.Warning(MissingActionOutputMessage("Discrete", deterministicInference))
                        );
                        return false;
                    }
                }
            }
            return true;
        }

        /// <summary>
        /// Build the warning used when the action output tensor for the requested inference mode
        /// is missing or has no size.
        /// </summary>
        /// <param name="actionKind">"Continuous" or "Discrete".</param>
        /// <param name="deterministicInference">Whether deterministic inference was requested.</param>
        private static string MissingActionOutputMessage(string actionKind, bool deterministicInference)
        {
            var actionType = deterministicInference ? "deterministic" : "stochastic";
            // The separating space is part of the prefix so the stochastic message has no double space.
            var actionName = deterministicInference ? "Deterministic " : "";
            return $"The model uses {actionType} inference but does not contain {actionName}{actionKind} Action Output Tensor. Uncheck `Deterministic inference` flag.";
        }
    }
}
|
|