| using System; |
| using System.Collections.Generic; |
| using Unity.MLAgents.Sensors; |
| using UnityEngine; |
|
|
| namespace Unity.MLAgents.Integrations.Match3 |
| { |
| |
| |
| |
| |
| |
/// <summary>
/// Callback that returns the integer value stored in one cell of a Match-3 board.
/// Within this file it is always invoked with the row index first and the column index second.
/// </summary>
/// <param name="x">Row index of the cell.</param>
/// <param name="y">Column index of the cell.</param>
/// <returns>The cell's value; a sensor one-hot encodes it.</returns>
public delegate int GridValueProvider(int x, int y);

/// <summary>
/// Chooses the form in which a Match-3 sensor reports its observation.
/// </summary>
public enum Match3ObservationType
{
    /// <summary>
    /// A flat float vector of length maxRows * maxColumns * oneHotSize.
    /// </summary>
    Vector = 0,

    /// <summary>
    /// A (maxRows, maxColumns, oneHotSize) visual observation that is not compressed.
    /// </summary>
    UncompressedVisual = 1,

    /// <summary>
    /// A visual observation delivered as PNG data. Each image packs three one-hot channels
    /// into its red, green and blue components.
    /// </summary>
    CompressedVisual = 2
}
|
|
| |
| |
| |
| |
/// <summary>
/// Sensor that observes a Match-3 <see cref="AbstractBoard"/>.
/// Each cell's integer value, obtained from a <see cref="GridValueProvider"/>, is one-hot encoded.
/// Cells that lie outside the board's current size but inside its maximum size are written as all zeros.
/// As a result, the observation always has the shape of the maximum board.
/// </summary>
public class Match3Sensor : ISensor, IBuiltInSensor, IDisposable
{
    readonly Match3ObservationType m_ObservationType;
    readonly ObservationSpec m_ObservationSpec;
    readonly string m_Name;

    // Board being observed, and the size every observation is padded to.
    readonly AbstractBoard m_Board;
    readonly BoardSize m_MaxBoardSize;

    // Per-cell value source (invoked as (row, column)) and the number of one-hot channels per cell.
    readonly GridValueProvider m_GridValues;
    readonly int m_OneHotSize;

    // Created lazily by GetCompressedObservation and reused across calls.
    // The texture is destroyed in Dispose.
    Texture2D m_ObservationTexture;
    OneHotToTextureUtil m_TextureUtil;

    /// <summary>
    /// Creates a sensor for the given board.
    /// </summary>
    /// <param name="board">Board to observe. Its maximum board size fixes the observation shape.</param>
    /// <param name="gvp">Returns the value of the cell at (row, column).</param>
    /// <param name="oneHotSize">
    /// Number of one-hot channels per cell. A value outside [0, oneHotSize) is encoded as all zeros.
    /// </param>
    /// <param name="obsType">Whether to produce a vector, an uncompressed visual, or a PNG-compressed visual observation.</param>
    /// <param name="name">Name reported by <see cref="GetName"/>.</param>
    /// <exception cref="ArgumentNullException"><paramref name="board"/> or <paramref name="gvp"/> is null.</exception>
    public Match3Sensor(AbstractBoard board, GridValueProvider gvp, int oneHotSize, Match3ObservationType obsType, string name)
    {
        // Reference checks, matching the rest of this class (UnityEngine.Object overloads ==).
        if (ReferenceEquals(null, board))
        {
            throw new ArgumentNullException(nameof(board));
        }
        if (ReferenceEquals(null, gvp))
        {
            // Fail here rather than on the first Write/GetCompressedObservation call.
            throw new ArgumentNullException(nameof(gvp));
        }

        var maxBoardSize = board.GetMaxBoardSize();
        m_Name = name;
        m_MaxBoardSize = maxBoardSize;
        m_GridValues = gvp;
        m_OneHotSize = oneHotSize;
        m_Board = board;

        m_ObservationType = obsType;
        m_ObservationSpec = obsType == Match3ObservationType.Vector
            ? ObservationSpec.Vector(maxBoardSize.Rows * maxBoardSize.Columns * oneHotSize)
            : ObservationSpec.Visual(maxBoardSize.Rows, maxBoardSize.Columns, oneHotSize);
    }

    /// <summary>
    /// Creates a sensor that observes each cell's type.
    /// It uses <c>board.GetCellType</c> with NumCellTypes one-hot channels.
    /// </summary>
    /// <param name="board">Board to observe.</param>
    /// <param name="obsType">Observation format.</param>
    /// <param name="name">Sensor name.</param>
    /// <returns>The new sensor.</returns>
    public static Match3Sensor CellTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();
        return new Match3Sensor(board, board.GetCellType, maxBoardSize.NumCellTypes, obsType, name);
    }

    /// <summary>
    /// Creates a sensor that observes each cell's special type.
    /// It uses <c>board.GetSpecialType</c> with NumSpecialTypes + 1 one-hot channels.
    /// NOTE(review): the extra channel is presumably for a "no special type" value of 0 — confirm.
    /// </summary>
    /// <param name="board">Board to observe.</param>
    /// <param name="obsType">Observation format.</param>
    /// <param name="name">Sensor name.</param>
    /// <returns>The new sensor, or <c>null</c> when the board declares zero special types.</returns>
    public static Match3Sensor SpecialTypeSensor(AbstractBoard board, Match3ObservationType obsType, string name)
    {
        var maxBoardSize = board.GetMaxBoardSize();
        if (maxBoardSize.NumSpecialTypes == 0)
        {
            return null;
        }
        var specialSize = maxBoardSize.NumSpecialTypes + 1;
        return new Match3Sensor(board, board.GetSpecialType, specialSize, obsType, name);
    }

    /// <inheritdoc/>
    public ObservationSpec GetObservationSpec()
    {
        return m_ObservationSpec;
    }

    /// <summary>
    /// Writes the one-hot encoded board into <paramref name="writer"/>.
    /// The output is padded with zeros up to the maximum board size.
    /// </summary>
    /// <param name="writer">Destination for the observation.</param>
    /// <returns>Number of floats written: maxRows * maxColumns * oneHotSize.</returns>
    public int Write(ObservationWriter writer)
    {
        m_Board.CheckBoardSizes(m_MaxBoardSize);
        var currentBoardSize = m_Board.GetCurrentBoardSize();

        var offset = 0;
        // Visual observations are indexed by (row, column, channel); vector ones by flat offset.
        var isVisual = m_ObservationType != Match3ObservationType.Vector;

        // Row-major order. In each live row, the current cells come first, then zero padding out to the
        // maximum column count. This keeps the flat vector layout aligned with the maximum board size.
        for (var r = 0; r < currentBoardSize.Rows; r++)
        {
            for (var c = 0; c < currentBoardSize.Columns; c++)
            {
                var val = m_GridValues(r, c);
                writer.WriteOneHot(offset, r, c, val, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }

            for (var c = currentBoardSize.Columns; c < m_MaxBoardSize.Columns; c++)
            {
                writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }
        }

        // Zero-fill the remaining rows up to the maximum ROW count. The previous bound was the column count,
        // which wrote past the observation or left rows unwritten on non-square boards.
        for (var r = currentBoardSize.Rows; r < m_MaxBoardSize.Rows; r++)
        {
            for (var c = 0; c < m_MaxBoardSize.Columns; c++)
            {
                writer.WriteZero(offset, r, c, m_OneHotSize, isVisual);
                offset += m_OneHotSize;
            }
        }

        return offset;
    }

    /// <summary>
    /// Encodes the board as one or more PNG images concatenated into a single byte array.
    /// Each image carries three consecutive one-hot channels in its RGB components,
    /// so ceil(oneHotSize / 3) images are produced.
    /// </summary>
    /// <returns>The concatenated PNG bytes.</returns>
    public byte[] GetCompressedObservation()
    {
        m_Board.CheckBoardSizes(m_MaxBoardSize);
        var height = m_MaxBoardSize.Rows;
        var width = m_MaxBoardSize.Columns;

        // Reference checks avoid UnityEngine.Object's overloaded null comparison.
        if (ReferenceEquals(null, m_ObservationTexture))
        {
            m_ObservationTexture = new Texture2D(width, height, TextureFormat.RGB24, false);
        }

        if (ReferenceEquals(null, m_TextureUtil))
        {
            m_TextureUtil = new OneHotToTextureUtil(height, width);
        }
        var bytesOut = new List<byte>();
        var currentBoardSize = m_Board.GetCurrentBoardSize();

        // Integer ceiling of oneHotSize / 3: image i covers channels 3*i .. 3*i+2.
        var numCellImages = (m_OneHotSize + 2) / 3;
        for (var i = 0; i < numCellImages; i++)
        {
            m_TextureUtil.EncodeToTexture(
                m_GridValues,
                m_ObservationTexture,
                3 * i,
                currentBoardSize.Rows,
                currentBoardSize.Columns
            );
            bytesOut.AddRange(m_ObservationTexture.EncodeToPNG());
        }

        return bytesOut.ToArray();
    }

    /// <inheritdoc/>
    public void Update()
    {
    }

    /// <inheritdoc/>
    public void Reset()
    {
    }

    /// <summary>
    /// PNG compression for <see cref="Match3ObservationType.CompressedVisual"/>; otherwise none.
    /// </summary>
    internal SensorCompressionType GetCompressionType()
    {
        return m_ObservationType == Match3ObservationType.CompressedVisual ?
            SensorCompressionType.PNG :
            SensorCompressionType.None;
    }

    /// <inheritdoc/>
    public CompressionSpec GetCompressionSpec()
    {
        return new CompressionSpec(GetCompressionType());
    }

    /// <inheritdoc/>
    public string GetName()
    {
        return m_Name;
    }

    /// <inheritdoc/>
    public BuiltInSensorType GetBuiltInSensorType()
    {
        return BuiltInSensorType.Match3Sensor;
    }

    /// <summary>
    /// Destroys the texture created by <see cref="GetCompressedObservation"/>, if any.
    /// Safe to call more than once.
    /// </summary>
    public void Dispose()
    {
        if (!ReferenceEquals(null, m_ObservationTexture))
        {
            Utilities.DestroyTexture(m_ObservationTexture);
            m_ObservationTexture = null;
        }
    }
}
|
|
| |
| |
| |
| |
| |
| |
/// <summary>
/// Packs up to three one-hot channels of a grid into the red, green and blue components of a texture.
/// A cell whose value equals channelOffset, channelOffset + 1 or channelOffset + 2 becomes pure red,
/// green or blue respectively. Every other cell, including padding beyond the current grid size, is black.
/// </summary>
internal class OneHotToTextureUtil
{
    // Reusable pixel buffer with one entry per cell of the maximum-sized grid; filled on every encode.
    readonly Color[] m_Colors;
    readonly int m_MaxHeight;
    readonly int m_MaxWidth;

    // Colors assigned to the three channels packed into one texture, in channel order.
    // Readonly so the shared table cannot be reassigned.
    private static readonly Color[] s_OneHotColors = { Color.red, Color.green, Color.blue };

    /// <summary>
    /// Allocates the pixel buffer for a grid of at most <paramref name="maxHeight"/> rows
    /// by <paramref name="maxWidth"/> columns.
    /// </summary>
    /// <param name="maxHeight">Maximum number of rows.</param>
    /// <param name="maxWidth">Maximum number of columns.</param>
    public OneHotToTextureUtil(int maxHeight, int maxWidth)
    {
        m_Colors = new Color[maxHeight * maxWidth];
        m_MaxHeight = maxHeight;
        m_MaxWidth = maxWidth;
    }

    /// <summary>
    /// Fills <paramref name="texture"/> with the RGB encoding of channels
    /// [channelOffset, channelOffset + 3) of the grid.
    /// </summary>
    /// <param name="gridValueProvider">Returns the value of the cell at (row, column).</param>
    /// <param name="texture">
    /// Destination texture. Expected to be maxWidth x maxHeight so that it matches the pixel buffer.
    /// </param>
    /// <param name="channelOffset">First one-hot channel encoded into this texture.</param>
    /// <param name="currentHeight">Number of live rows; rows at or beyond this index are black.</param>
    /// <param name="currentWidth">Number of live columns; columns at or beyond this index are black.</param>
    public void EncodeToTexture(
        GridValueProvider gridValueProvider,
        Texture2D texture,
        int channelOffset,
        int currentHeight,
        int currentWidth
    )
    {
        var i = 0;
        // Texture2D.SetPixels lays rows out bottom to top, so the last grid row is written first.
        // This places grid row 0 at the top of the encoded image.
        for (var h = m_MaxHeight - 1; h >= 0; h--)
        {
            for (var w = 0; w < m_MaxWidth; w++)
            {
                var colorVal = Color.black;
                if (h < currentHeight && w < currentWidth)
                {
                    int oneHotValue = gridValueProvider(h, w);
                    if (oneHotValue >= channelOffset && oneHotValue < channelOffset + 3)
                    {
                        colorVal = s_OneHotColors[oneHotValue - channelOffset];
                    }
                }
                m_Colors[i++] = colorVal;
            }
        }
        texture.SetPixels(m_Colors);
    }
}
|
|
| |
| |
| |
/// <summary>
/// Helpers for writing per-cell one-hot data into an <see cref="ObservationWriter"/>.
/// Visual observations are addressed as (row, column, channel).
/// Vector observations are addressed as a flat index starting at the given offset.
/// </summary>
internal static class ObservationWriterMatch3Extensions
{
    /// <summary>
    /// Writes <paramref name="oneHotSize"/> floats for one cell.
    /// The channel equal to <paramref name="value"/> gets 1 and all others get 0.
    /// A value outside [0, oneHotSize) yields all zeros.
    /// </summary>
    /// <param name="writer">Destination writer.</param>
    /// <param name="offset">Flat start index; used only when <paramref name="isVisual"/> is false.</param>
    /// <param name="row">Row index; used only when <paramref name="isVisual"/> is true.</param>
    /// <param name="col">Column index; used only when <paramref name="isVisual"/> is true.</param>
    /// <param name="value">Hot channel index.</param>
    /// <param name="oneHotSize">Number of channels to write.</param>
    /// <param name="isVisual">Selects (row, col, channel) addressing instead of flat addressing.</param>
    public static void WriteOneHot(this ObservationWriter writer, int offset, int row, int col, int value, int oneHotSize, bool isVisual)
    {
        for (var channel = 0; channel < oneHotSize; channel++)
        {
            var encoded = channel == value ? 1.0f : 0.0f;
            if (isVisual)
            {
                writer[row, col, channel] = encoded;
            }
            else
            {
                writer[offset + channel] = encoded;
            }
        }
    }

    /// <summary>
    /// Writes <paramref name="oneHotSize"/> zeros for one cell, typically a padding cell.
    /// </summary>
    /// <param name="writer">Destination writer.</param>
    /// <param name="offset">Flat start index; used only when <paramref name="isVisual"/> is false.</param>
    /// <param name="row">Row index; used only when <paramref name="isVisual"/> is true.</param>
    /// <param name="col">Column index; used only when <paramref name="isVisual"/> is true.</param>
    /// <param name="oneHotSize">Number of channels to write.</param>
    /// <param name="isVisual">Selects (row, col, channel) addressing instead of flat addressing.</param>
    public static void WriteZero(this ObservationWriter writer, int offset, int row, int col, int oneHotSize, bool isVisual)
    {
        for (var channel = 0; channel < oneHotSize; channel++)
        {
            if (isVisual)
            {
                writer[row, col, channel] = 0.0f;
            }
            else
            {
                writer[offset + channel] = 0.0f;
            }
        }
    }
}
| } |
|
|