| using System.Collections.Generic; |
| using Unity.Barracuda; |
| using Unity.MLAgents.Sensors; |
|
|
| namespace Unity.MLAgents.Inference |
| { |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
/// <summary>
/// Owns the mapping from input-tensor name to the <see cref="IGenerator"/> that fills that
/// tensor, and dispatches to those generators in <see cref="GenerateTensors"/>.
/// </summary>
internal class TensorGenerator
{
    /// <summary>
    /// A generator responsible for one named input tensor.
    /// </summary>
    public interface IGenerator
    {
        /// <summary>
        /// Populates <paramref name="tensorProxy"/> for the current batch. Called by
        /// <see cref="TensorGenerator.GenerateTensors"/> for each input tensor whose name this
        /// generator is registered under.
        /// </summary>
        /// <param name="tensorProxy">The tensor to populate.</param>
        /// <param name="batchSize">The current batch size, as passed to GenerateTensors.</param>
        /// <param name="infos">The agent info / sensor pairs for the current batch.</param>
        void Generate(
            TensorProxy tensorProxy, int batchSize, IList<AgentInfoSensorsPair> infos);
    }

    // Tensor name -> generator that fills the tensor with that name.
    readonly Dictionary<string, IGenerator> m_Dict = new Dictionary<string, IGenerator>();

    // API version reported by the model; left at 0 when no model is supplied.
    readonly int m_ApiVersion;

    /// <summary>
    /// Registers the generators for the model's non-observation tensors (batch size, sequence
    /// length, recurrent input, previous action, action mask, random epsilon, and the output
    /// tensors). Observation generators are registered separately by
    /// <see cref="InitializeObservations"/>.
    /// </summary>
    /// <param name="seed">Seed forwarded to the <c>RandomNormalInputGenerator</c>.</param>
    /// <param name="allocator">Allocator handed to every generator created here.</param>
    /// <param name="memories">
    /// Forwarded to the <c>RecurrentInputGenerator</c>. NOTE(review): the int key is presumably an
    /// agent id — confirm against callers.
    /// </param>
    /// <param name="barracudaModel">
    /// The Barracuda <c>Model</c>, passed as <c>object</c>. When null, no generators are
    /// registered and the API version stays 0. Any non-null value that is not a <c>Model</c>
    /// makes the cast below throw <see cref="System.InvalidCastException"/>.
    /// </param>
    /// <param name="deterministicInference">
    /// Forwarded to the model's output-name queries to pick which continuous/discrete output
    /// tensor names get generators.
    /// </param>
    public TensorGenerator(
        int seed,
        ITensorAllocator allocator,
        Dictionary<int, List<float>> memories,
        object barracudaModel = null,
        bool deterministicInference = false)
    {
        // Without a model there is nothing to describe the inputs, so leave the map empty.
        if (barracudaModel == null)
        {
            return;
        }
        var model = (Model)barracudaModel;

        m_ApiVersion = model.GetVersion();

        // Generators for the non-observation input placeholders.
        m_Dict[TensorNames.BatchSizePlaceholder] =
            new BatchSizeGenerator(allocator);
        m_Dict[TensorNames.SequenceLengthPlaceholder] =
            new SequenceLengthGenerator(allocator);
        m_Dict[TensorNames.RecurrentInPlaceholder] =
            new RecurrentInputGenerator(allocator, memories);

        m_Dict[TensorNames.PreviousActionPlaceholder] =
            new PreviousActionInputGenerator(allocator);
        m_Dict[TensorNames.ActionMaskPlaceholder] =
            new ActionMaskInputGenerator(allocator);
        m_Dict[TensorNames.RandomNormalEpsilonPlaceholder] =
            new RandomNormalInputGenerator(seed, allocator);

        // Output tensors also get a generator; only the action outputs the model actually
        // exposes are registered.
        if (model.HasContinuousOutputs(deterministicInference))
        {
            m_Dict[model.ContinuousOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
        }
        if (model.HasDiscreteOutputs(deterministicInference))
        {
            m_Dict[model.DiscreteOutputName(deterministicInference)] = new BiDimensionalOutputGenerator(allocator);
        }
        m_Dict[TensorNames.RecurrentOutput] = new BiDimensionalOutputGenerator(allocator);
        m_Dict[TensorNames.ValueEstimateOutput] = new BiDimensionalOutputGenerator(allocator);
    }

    /// <summary>
    /// Registers one observation generator per observation input tensor, using the tensor
    /// naming scheme of the model's API version. For any API version other than MLAgents1_0 or
    /// MLAgents2_0 (including 0, when no model was given) nothing is registered.
    /// </summary>
    /// <param name="sensors">
    /// The sensors whose observations feed the model. The list position is used as the sensor
    /// index passed to each generator.
    /// </param>
    /// <param name="allocator">Allocator handed to every generator created here.</param>
    /// <exception cref="UnityAgentsException">
    /// MLAgents1_0 only: a sensor's observation spec has a rank other than 1, 2 or 3.
    /// </exception>
    public void InitializeObservations(List<ISensor> sensors, ITensorAllocator allocator)
    {
        if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents1_0)
        {
            AddObservationGeneratorsByRank(sensors, allocator);
        }
        else if (m_ApiVersion == (int)BarracudaModelParamLoader.ModelApiVersion.MLAgents2_0)
        {
            AddObservationGeneratorsPerSensor(sensors, allocator);
        }
    }

    // MLAgents1_0 naming: all rank-1 sensors share one generator and tensor; rank-2/3 sensors each get their own.
    void AddObservationGeneratorsByRank(List<ISensor> sensors, ITensorAllocator allocator)
    {
        var visIndex = 0;
        ObservationGenerator vecObsGen = null;
        for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
        {
            var sensor = sensors[sensorIndex];
            var rank = sensor.GetObservationSpec().Rank;
            ObservationGenerator obsGen;
            string obsGenName;
            switch (rank)
            {
                case 1:
                    // Every rank-1 sensor is added to the same generator, registered under
                    // the single vector-observation placeholder.
                    if (vecObsGen == null)
                    {
                        vecObsGen = new ObservationGenerator(allocator);
                    }
                    obsGen = vecObsGen;
                    obsGenName = TensorNames.VectorObservationPlaceholder;
                    break;
                case 2:
                    // Rank-2 tensors are named by the sensor's index in the full list.
                    obsGen = new ObservationGenerator(allocator);
                    obsGenName = TensorNames.GetObservationName(sensorIndex);
                    break;
                case 3:
                    // Rank-3 tensors are named by a visual index that counts only rank-3 sensors.
                    obsGen = new ObservationGenerator(allocator);
                    obsGenName = TensorNames.GetVisualObservationName(visIndex);
                    visIndex++;
                    break;
                default:
                    throw new UnityAgentsException(
                        $"Sensor {sensor.GetName()} have an invalid rank {rank}");
            }
            obsGen.AddSensorIndex(sensorIndex);
            m_Dict[obsGenName] = obsGen;
        }
    }

    // MLAgents2_0 naming: one generator per sensor, named by the sensor's index regardless of rank.
    void AddObservationGeneratorsPerSensor(List<ISensor> sensors, ITensorAllocator allocator)
    {
        for (var sensorIndex = 0; sensorIndex < sensors.Count; sensorIndex++)
        {
            var obsGen = new ObservationGenerator(allocator);
            var obsGenName = TensorNames.GetObservationName(sensorIndex);
            obsGen.AddSensorIndex(sensorIndex);
            m_Dict[obsGenName] = obsGen;
        }
    }

    /// <summary>
    /// Fills every tensor in <paramref name="tensors"/>, in list order, by calling the
    /// generator registered under that tensor's name.
    /// </summary>
    /// <param name="tensors">The input tensors to populate.</param>
    /// <param name="currentBatchSize">Batch size forwarded to each generator.</param>
    /// <param name="infos">Agent info / sensor pairs forwarded to each generator.</param>
    /// <exception cref="UnityAgentsException">
    /// A tensor's name has no registered generator. Tensors earlier in the list have already
    /// been generated when this is thrown.
    /// </exception>
    public void GenerateTensors(
        IReadOnlyList<TensorProxy> tensors, int currentBatchSize, IList<AgentInfoSensorsPair> infos)
    {
        for (var tensorIndex = 0; tensorIndex < tensors.Count; tensorIndex++)
        {
            var tensor = tensors[tensorIndex];
            // Single hash lookup instead of ContainsKey followed by the indexer.
            if (!m_Dict.TryGetValue(tensor.name, out var generator))
            {
                throw new UnityAgentsException(
                    $"Unknown tensorProxy expected as input : {tensor.name}");
            }
            generator.Generate(tensor, currentBatchSize, infos);
        }
    }
}
| } |
|
|