class UnificationFailure(Exception):
    """Signals that two types could not be unified.

    Instances are ordinary exceptions; any positional arguments given at
    construction time are kept in ``args``.
    """
|
|
|
|
class Occurs(UnificationFailure):
    """Unification failure caused by the occurs check.

    Being a subclass of UnificationFailure, it is caught by handlers for
    that exception as well.
    """
|
|
|
|
class Type(object):
    """Common base class for types.

    Subclasses are expected to provide ``show(isReturn)``; ``__str__`` and
    ``__repr__`` both delegate to it with ``isReturn=True``.
    """

    def __str__(self):
        return self.show(True)

    def __repr__(self):
        return str(self)

    @staticmethod
    def fromjson(j):
        """Rebuild a type from its JSON-style dictionary form.

        A dictionary with an ``"index"`` key becomes a TypeVariable; one with
        a ``"constructor"`` key becomes a TypeConstructor whose
        ``"arguments"`` are decoded recursively.

        Raises:
            AssertionError: if ``j`` has neither key.
        """
        if "index" in j:
            return TypeVariable(j["index"])
        if "constructor" in j:
            return TypeConstructor(j["constructor"],
                                   [Type.fromjson(a) for a in j["arguments"]])
        # Raised explicitly rather than via ``assert False`` so the check is
        # not stripped (and None silently returned) when running with -O.
        raise AssertionError("cannot decode a type from JSON object: %r" % (j,))
|
|
|
|
class TypeConstructor(Type):
    """A named type constructor applied to a sequence of argument types.

    Attributes (all set in ``__init__``):
        name: the constructor's name; function types use the name ``ARROW``.
        arguments: the argument types.
        isPolymorphic: True when at least one argument is polymorphic.
    """

    def __init__(self, name, arguments):
        self.name = name
        self.arguments = arguments
        # Cached so the traversals below can return early on types that
        # contain no polymorphic argument.
        self.isPolymorphic = any(a.isPolymorphic for a in arguments)

    def makeDummyMonomorphic(self, mapping=None):
        """Rebuild this type with ``makeDummyMonomorphic`` applied to each argument.

        The same ``mapping`` dictionary is passed to every argument; a fresh
        one is created when none is supplied.
        """
        mapping = mapping if mapping is not None else {}
        return TypeConstructor(self.name,
                               [a.makeDummyMonomorphic(mapping) for a in self.arguments])

    def __eq__(self, other):
        """Structural equality: same name, same arity, pairwise-equal arguments."""
        return isinstance(other, TypeConstructor) and \
            self.name == other.name and \
            len(self.arguments) == len(other.arguments) and \
            all(x == y for x, y in zip(self.arguments, other.arguments))
        # The explicit length comparison matters: zip() stops at the shorter
        # sequence, which would otherwise make f(a) equal to f(a, b) even
        # though their hashes differ.

    def __hash__(self):
        return hash((self.name,) + tuple(self.arguments))

    def __ne__(self, other):
        return not (self == other)

    def show(self, isReturn):
        """Render the type as a string.

        Arrow types are written infix; they are wrapped in parentheses unless
        ``isReturn`` is true. A constructor without arguments is shown as its
        bare name, any other as ``name(arg, ...)``.
        """
        if self.name == ARROW:
            text = "%s %s %s" % (self.arguments[0].show(False),
                                 ARROW,
                                 self.arguments[1].show(True))
            return text if isReturn else "(%s)" % text
        # Length test rather than ``== []`` so an empty tuple of arguments is
        # also rendered as a bare name.
        if len(self.arguments) == 0:
            return self.name
        return "%s(%s)" % (self.name,
                           ", ".join(x.show(True) for x in self.arguments))

    def json(self):
        """Return the dictionary form read back by ``Type.fromjson``."""
        return {"constructor": self.name,
                "arguments": [a.json() for a in self.arguments]}

    def isArrow(self):
        return self.name == ARROW

    def functionArguments(self):
        """Return the argument types of a (curried) arrow type, else ``[]``."""
        if self.name == ARROW:
            xs = self.arguments[1].functionArguments()
            return [self.arguments[0]] + xs
        return []

    def returns(self):
        """Return the final result type after following every arrow."""
        if self.name == ARROW:
            return self.arguments[1].returns()
        else:
            return self

    def apply(self, context):
        """Return this type with ``apply(context)`` applied to each argument.

        Returns ``self`` unchanged when the type is not polymorphic.
        """
        if not self.isPolymorphic:
            return self
        return TypeConstructor(self.name,
                               [x.apply(context) for x in self.arguments])

    def applyMutable(self, context):
        """Like ``apply`` but delegating to each argument's ``applyMutable``."""
        if not self.isPolymorphic:
            return self
        return TypeConstructor(self.name,
                               [x.applyMutable(context) for x in self.arguments])

    def occurs(self, v):
        """Return True if any argument reports that variable ``v`` occurs in it."""
        if not self.isPolymorphic:
            return False
        return any(x.occurs(v) for x in self.arguments)

    def negateVariables(self):
        """Rebuild this type with ``negateVariables`` applied to each argument."""
        return TypeConstructor(self.name,
                               [a.negateVariables() for a in self.arguments])

    def instantiate(self, context, bindings=None):
        """Instantiate each argument in turn, threading the context through.

        Returns:
            A ``(context, type)`` pair. When the type is not polymorphic this
            is the incoming context and ``self``.
        """
        if not self.isPolymorphic:
            return context, self
        if bindings is None:
            bindings = {}
        newArguments = []
        for x in self.arguments:
            # Each argument may return an updated context; the next argument
            # is instantiated against that one.
            (context, x) = x.instantiate(context, bindings)
            newArguments.append(x)
        return (context, TypeConstructor(self.name, newArguments))

    def instantiateMutable(self, context, bindings=None):
        """Instantiate each argument via ``instantiateMutable``.

        Returns only the new type (``self`` when not polymorphic); the context
        is passed unchanged to every argument.
        """
        if not self.isPolymorphic:
            return self
        if bindings is None:
            bindings = {}
        return TypeConstructor(self.name,
                               [x.instantiateMutable(context, bindings)
                                for x in self.arguments])

    def canonical(self, bindings=None):
        """Rebuild this type with ``canonical(bindings)`` applied to each argument.

        Returns ``self`` unchanged when the type is not polymorphic.
        """
        if not self.isPolymorphic:
            return self
        if bindings is None:
            bindings = {}
        return TypeConstructor(self.name,
                               [x.canonical(bindings) for x in self.arguments])
|
|
|
|
class TypeVariable(Type):
    """A type variable identified by the integer ``v``."""

    def __init__(self, j):
        assert isinstance(j, int)
        self.v = j
        self.isPolymorphic = True

    def makeDummyMonomorphic(self, mapping=None):
        """Return the dummy base type that ``mapping`` assigns to this variable.

        On first sight of a variable index, a zero-argument constructor named
        ``dummy_type_<k>`` is stored in ``mapping``, where ``k`` is the number
        of entries already present.
        """
        mapping = mapping if mapping is not None else {}
        if self.v not in mapping:
            mapping[self.v] = TypeConstructor(f"dummy_type_{len(mapping)}", [])
        return mapping[self.v]

    def __eq__(self, other):
        return isinstance(other, TypeVariable) and self.v == other.v

    def __ne__(self, other):
        # Defined through __eq__ so that comparing against an object without
        # a ``v`` attribute yields True instead of raising AttributeError.
        return not (self == other)

    def __hash__(self):
        return self.v

    def show(self, _):
        return "t%d" % self.v

    def json(self):
        """Return the dictionary form read back by ``Type.fromjson``."""
        return {"index": self.v}

    def returns(self):
        return self

    def isArrow(self):
        return False

    def functionArguments(self):
        return []

    def apply(self, context):
        """Resolve this variable against ``context.substitution``.

        The substitution is scanned as a sequence of ``(index, type)`` pairs;
        the first pair matching this variable's index is used and its type is
        itself resolved. Returns ``self`` when no pair matches.
        """
        for v, t in context.substitution:
            if v == self.v:
                return t.apply(context)
        return self

    def applyMutable(self, context):
        """Resolve this variable against an index-addressed substitution.

        ``context.substitution[self.v]`` is read; ``None`` means unbound and
        ``self`` is returned. Otherwise the entry is resolved and the resolved
        type is written back into the same slot before being returned.
        """
        s = context.substitution[self.v]
        if s is None:
            return self
        new = s.applyMutable(context)
        context.substitution[self.v] = new
        return new

    def occurs(self, v):
        return v == self.v

    def instantiate(self, context, bindings=None):
        """Map this variable to a fresh one, reusing ``bindings`` when possible.

        Returns:
            A ``(context, variable)`` pair. When a new variable is created it
            is numbered ``context.nextVariable`` and the returned context has
            that counter incremented.
        """
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return (context, bindings[self.v])
        new = TypeVariable(context.nextVariable)
        bindings[self.v] = new
        context = Context(context.nextVariable + 1, context.substitution)
        return (context, new)

    def instantiateMutable(self, context, bindings=None):
        """Like ``instantiate`` but obtaining the fresh variable from
        ``context.makeVariable()`` and returning only the variable."""
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return bindings[self.v]
        new = context.makeVariable()
        bindings[self.v] = new
        return new

    def canonical(self, bindings=None):
        """Renumber this variable by order of first appearance in ``bindings``.

        A variable not yet in ``bindings`` is assigned the index
        ``len(bindings)``.
        """
        if bindings is None:
            bindings = {}
        if self.v in bindings:
            return bindings[self.v]
        new = TypeVariable(len(bindings))
        bindings[self.v] = new
        return new

    def negateVariables(self):
        """Return the variable with index ``-1 - v``."""
        return TypeVariable(-1 - self.v)
|
|
|
|
class Context(object):
    """A unification context that is never modified in place.

    Attributes:
        nextVariable: index given to the next variable made by ``makeVariable``.
        substitution: list of ``(variable index, type)`` pairs; ``extend``
            puts new pairs at the front.

    ``extend``, ``makeVariable`` and ``unify`` return contexts instead of
    mutating ``self``.
    """

    def __init__(self, nextVariable=0, substitution=None):
        self.nextVariable = nextVariable
        # A None default instead of ``[]`` so default-constructed contexts do
        # not all share one list object.
        self.substitution = [] if substitution is None else substitution

    def extend(self, j, t):
        """Return a new context in which variable index ``j`` is bound to ``t``."""
        return Context(self.nextVariable, [(j, t)] + self.substitution)

    def makeVariable(self):
        """Return ``(new context, fresh TypeVariable)``.

        The variable is numbered ``self.nextVariable``; the new context has
        that counter incremented and reuses this context's substitution.
        """
        return (Context(self.nextVariable + 1, self.substitution),
                TypeVariable(self.nextVariable))

    def unify(self, t1, t2):
        """Unify ``t1`` with ``t2`` and return the resulting context.

        Returns ``self`` when the two types are already equal after applying
        the current substitution.

        Raises:
            UnificationFailure: both types are non-polymorphic and differ, or
                the constructors differ in name or in number of arguments.
            Occurs: a variable would be bound to a type that contains it.
        """
        t1 = t1.apply(self)
        t2 = t2.apply(self)
        if t1 == t2:
            return self

        # Two distinct types without polymorphic parts can never be unified.
        if not t1.isPolymorphic and not t2.isPolymorphic:
            raise UnificationFailure(t1, t2)

        if isinstance(t1, TypeVariable):
            if t2.occurs(t1.v):
                raise Occurs()
            return self.extend(t1.v, t2)
        if isinstance(t2, TypeVariable):
            if t1.occurs(t2.v):
                raise Occurs()
            return self.extend(t2.v, t1)
        if t1.name != t2.name:
            raise UnificationFailure(t1, t2)
        # zip() below would silently ignore surplus arguments, so a mismatch
        # in arity is reported as a failure up front.
        if len(t1.arguments) != len(t2.arguments):
            raise UnificationFailure(t1, t2)
        k = self
        for x, y in zip(t2.arguments, t1.arguments):
            k = k.unify(x, y)
        return k

    def __str__(self):
        return "Context(next = %d, {%s})" % (self.nextVariable, ", ".join(
            "t%d ||> %s" % (k, v.apply(self)) for k, v in self.substitution))

    def __repr__(self):
        return str(self)
|
|
class MutableContext(object):
    """A unification context that is updated in place.

    ``substitution`` is a list indexed by variable number; an entry of
    ``None`` marks a variable that has no binding yet.
    """

    def __init__(self):
        self.substitution = []

    def extend(self, i, t):
        """Bind variable ``i`` to ``t``; the variable must currently be unbound."""
        assert self.substitution[i] is None
        self.substitution[i] = t

    def makeVariable(self):
        """Append an unbound slot and return the TypeVariable that indexes it."""
        self.substitution.append(None)
        return TypeVariable(len(self.substitution) - 1)

    def unify(self, t1, t2):
        """Unify ``t1`` with ``t2``, recording bindings in this context.

        Returns None. Bindings made before a failure is detected are not
        rolled back.

        Raises:
            UnificationFailure: both types are non-polymorphic and differ, or
                the constructors differ in name or in number of arguments.
            Occurs: a variable would be bound to a type that contains it.
        """
        t1 = t1.applyMutable(self)
        t2 = t2.applyMutable(self)

        if t1 == t2:
            return

        # Two distinct types without polymorphic parts can never be unified.
        if not t1.isPolymorphic and not t2.isPolymorphic:
            raise UnificationFailure(t1, t2)

        if isinstance(t1, TypeVariable):
            if t2.occurs(t1.v):
                raise Occurs()
            self.extend(t1.v, t2)
            return
        if isinstance(t2, TypeVariable):
            if t1.occurs(t2.v):
                raise Occurs()
            self.extend(t2.v, t1)
            return
        if t1.name != t2.name:
            raise UnificationFailure(t1, t2)
        # zip() below would silently ignore surplus arguments; checking the
        # arity first also means no binding is written for such a pair.
        if len(t1.arguments) != len(t2.arguments):
            raise UnificationFailure(t1, t2)

        for x, y in zip(t2.arguments, t1.arguments):
            self.unify(x, y)
|
|
|
|
# Shared starting context: variable counter at zero and no substitutions.
Context.EMPTY = Context(nextVariable=0, substitution=[])
|
|
|
|
def canonicalTypes(ts):
    """Canonicalize every type in ``ts`` against one shared bindings dict.

    Because the dictionary is shared, whatever each type's ``canonical``
    records in it is visible to the types that follow. Returns the results
    as a list in the same order as ``ts``.
    """
    shared = {}
    canonicalized = []
    for tp in ts:
        canonicalized.append(tp.canonical(shared))
    return canonicalized
|
|
|
|
def instantiateTypes(context, ts):
    """Instantiate every type in ``ts``, sharing one bindings dict.

    The context returned by each ``instantiate`` call is the one handed to
    the next call. Returns ``(final context, list of instantiated types)``.
    """
    shared = {}
    fresh = []
    for original in ts:
        context, instantiated = original.instantiate(context, shared)
        fresh.append(instantiated)
    return context, fresh
|
|
|
|
def baseType(n):
    """Return the type constructor named ``n`` with an empty argument list."""
    no_arguments = []
    return TypeConstructor(n, no_arguments)
|
|
|
|
# Primitive base types, one zero-argument constructor per name.
tint, treal, tbool, tcharacter = (
    baseType(label) for label in ("int", "real", "bool", "char")
)
# Alias: tboolean is the very same object as tbool.
tboolean = tbool
|
|
|
|
def tlist(t):
    """Return the ``list`` type constructor applied to the single argument ``t``."""
    element_types = [t]
    return TypeConstructor("list", element_types)
|
|
|
|
def tpair(a, b):
    """Return the ``pair`` type constructor applied to ``a`` and ``b``, in that order."""
    components = [a, b]
    return TypeConstructor("pair", components)
|
|
|
|
def tmaybe(t):
    """Return the ``maybe`` type constructor applied to the single argument ``t``."""
    wrapped = [t]
    return TypeConstructor("maybe", wrapped)
|
|
|
|
# A string type is a list of characters.
tstr = tlist(tcharacter)

# Ready-made type variables with indices 0, 1 and 2.
t0, t1, t2 = (TypeVariable(index) for index in range(3))

tpregex = baseType("pregex")

# Constructor name used for function (arrow) types.
ARROW = "->"
|
|
|
|
def arrow(*arguments):
    """Build a right-nested arrow type from the given types.

    ``arrow(a, b, c)`` yields an ARROW constructor over ``a`` and
    ``arrow(b, c)``; a single argument is returned as is. Calling it with no
    arguments raises IndexError.
    """
    # Start from the last type and wrap it leftwards, innermost arrow first.
    result = arguments[-1]
    for argument in reversed(arguments[:-1]):
        result = TypeConstructor(ARROW, [argument, result])
    return result
|
|
|
|
def inferArg(tp, tcaller):
    """Infer the argument type that makes ``tcaller`` an arrow returning ``tp``.

    Both types are instantiated starting from ``Context.EMPTY``, a fresh
    variable is created for the argument, and ``tcaller`` is unified with
    ``arrow(argument, tp)``. The argument variable is returned after the
    resulting substitution has been applied to it. Exceptions raised by
    ``unify`` propagate to the caller.
    """
    context, result_type = tp.instantiate(Context.EMPTY)
    context, caller_type = tcaller.instantiate(context)
    context, argument_type = context.makeVariable()
    expected_caller = arrow(argument_type, result_type)
    context = context.unify(caller_type, expected_caller)
    return argument_type.apply(context)
|
|
|
|
def guess_type(xs):
    """
    Return a type describing the Python values in xs.

    All bools give tbool, all ints give tint, all strs give tstr, and all
    lists give tlist of the type guessed from the concatenated elements.
    The checks run in that order, so an empty xs gives tbool.
    Raises ValueError if none of the checks holds for every value.
    """
    def every(kind):
        return all(isinstance(x, kind) for x in xs)

    # bool is tested ahead of int because every bool is also an int.
    if every(bool):
        return tbool
    if every(int):
        return tint
    if every(str):
        return tstr
    if every(list):
        flattened = [y for ys in xs for y in ys]
        return tlist(guess_type(flattened))
    raise ValueError("cannot guess type from {}".format(xs))
|
|
|
|
def guess_arrow_type(examples):
    """Guess an arrow type from ``(inputs, output)`` example pairs.

    The number of inputs is taken from the first example. The type of each
    input position is guessed from that position across all examples, the
    output type from all outputs, and the pieces are combined with ``arrow``.
    """
    arity = len(examples[0][0])
    input_types = [
        guess_type([inputs[position] for inputs, _ in examples])
        for position in range(arity)
    ]
    output_type = guess_type([output for _, output in examples])
    return arrow(*input_types, output_type)
|
|
def canUnify(t1, t2):
    """Return True if ``t1`` and ``t2`` unify, False on UnificationFailure.

    Both types are first instantiated in one fresh MutableContext, each with
    its own bindings, and then unified in that context.
    """
    context = MutableContext()
    left = t1.instantiateMutable(context)
    right = t2.instantiateMutable(context)
    try:
        context.unify(left, right)
    except UnificationFailure:
        return False
    return True
| |
|
|