# Source code for fosf.reasoning.clauses

#!/usr/bin/env python3

from collections import defaultdict
from typing import overload

from fosf.syntax.base import Feature, Tag
from fosf.syntax.constraints import (Constraint, Clause, FeatureConstraint,
                                    EqualityConstraint, SortConstraint,
                                    RootedClause, RootedSolvedClause, SolvedClause)

from fosf.syntax.taxonomy import SortTaxonomy


@overload
def normalize_clause(clause: RootedClause,
                     taxonomy: SortTaxonomy) -> RootedSolvedClause: ...


@overload
def normalize_clause(clause: Clause, taxonomy: SortTaxonomy) -> SolvedClause: ...


def normalize_clause(clause: Clause, taxonomy: SortTaxonomy):
    """Normalize ``clause`` with respect to ``taxonomy``.

    Convenience wrapper that creates a fresh :class:`ClauseNormalizer` and
    runs its ``normalize`` method once.

    Parameters
    ----------
    clause : Clause
        The clause to normalize.
    taxonomy : SortTaxonomy
        The sort taxonomy handed to the normalizer.

    Returns
    -------
    SolvedClause
        Whatever ``ClauseNormalizer.normalize`` returns; per the overloads,
        a ``RootedSolvedClause`` when ``clause`` is a ``RootedClause``.
    """
    normalizer = ClauseNormalizer()
    return normalizer.normalize(clause, taxonomy)
class ClauseNormalizer:
    """Normalize a clause made of sort, feature and equality constraints.

    Tags equated by the clause are grouped with a union-find structure. For
    the representative of each group the normalizer keeps a single sort code
    (the bitwise AND of the codes of every sort imposed on the group) and a
    single value tag per feature. Normalization fails as soon as the sort
    code of a group equals the taxonomy's ``bot_code``.

    All working state is rebuilt on every call to :meth:`normalize`, so one
    instance can be reused for several clauses.
    """

    def __init__(self):
        # The attributes are only declared here; they are created by
        # ``_init_structures`` at the start of each normalization.
        self.taxonomy: SortTaxonomy
        # X : s  ->  self.rep_to_code[X] = code(s)
        self.rep_to_code: dict[Tag, int]
        # X.f = Y  ->  self.rep_to_feats[X][f] = Y
        self.rep_to_feats: dict[Tag, dict[Feature, Tag]]
        # For union-find
        self._parents: dict
        self._indices: dict
        self._cost: dict

    def _init_structures(self, clause, taxonomy):
        """Reset all working state for ``clause`` and ``taxonomy``."""
        self.taxonomy = taxonomy
        # A representative without a sort constraint defaults to the top
        # sort; the lambda reads the taxonomy lazily, at lookup time.
        self.rep_to_code = defaultdict(lambda: self.taxonomy.top_code)
        self.rep_to_feats = defaultdict(dict)
        # For union-find
        self._parents = {}
        self._indices = {}
        # Number of features recorded on each representative.
        self._cost = defaultdict(int)
        for tag in clause.tags:
            self._add_tag(tag)

    # Union-find methods

    def _add_tag(self, X: Tag):
        """Register ``X`` as a singleton class; no-op if already known."""
        if X in self._indices:
            return
        self._parents[X] = X
        # Indices are handed out in registration order.
        self._indices[X] = len(self._indices)

    def deref_tag(self, X: Tag) -> Tag:
        """Return the representative of the class containing ``X``.

        Parent links along the visited path are shortened while walking up.

        Raises
        ------
        KeyError
            If ``X`` was never registered.
        """
        if X not in self._indices:
            raise KeyError(X)
        parents = self._parents
        indices = self._indices
        while indices[X] != indices[parents[X]]:
            # Point X at its grandparent, then step up.
            parents[X] = parents[parents[X]]
            X = parents[X]
        return X

    def _merge_tags(self, X: Tag, Y: Tag) -> tuple[bool, Tag, Tag]:
        """Union the classes of ``X`` and ``Y``.

        Returns
        -------
        tuple[bool, Tag, Tag]
            ``(merged, kept, absorbed)``. ``merged`` is ``False`` when both
            tags were already in the same class, in which case the last two
            items are that class's representative twice.
        """
        X, Y = self.deref_tag(X), self.deref_tag(Y)
        if self._indices[X] == self._indices[Y]:
            return False, X, Y
        # Merge Y into X or X into Y depending on cost: the representative
        # with more recorded features is kept; on a tie, the one with the
        # smaller index is kept.
        if (self._cost[X], self._indices[Y]) < (self._cost[Y], self._indices[X]):
            X, Y = Y, X
        self._parents[Y] = X
        return True, X, Y

    def _connected(self, X: Tag, Y: Tag) -> bool:
        """Tell whether ``X`` and ``Y`` belong to the same class."""
        return self._indices[self.deref_tag(X)] == self._indices[self.deref_tag(Y)]

    # Clause normalization methods

    def __call__(self, clause, taxonomy):
        """Alias for :meth:`normalize`."""
        return self.normalize(clause, taxonomy)

    @overload
    def normalize(self, clause: RootedClause,
                  taxonomy: SortTaxonomy) -> RootedSolvedClause: ...

    @overload
    def normalize(self, clause: Clause, taxonomy: SortTaxonomy) -> SolvedClause: ...

    def normalize(self, clause, taxonomy):
        """Normalize ``clause`` with respect to ``taxonomy``.

        The constraints are processed in order. If one of them makes the
        clause inconsistent, a ``RootedSolvedClause`` rooted at the tag
        ``_FAIL`` and holding the single constraint ``_FAIL : bot`` is
        returned. Otherwise the solved clause is built from the accumulated
        state, rooted at ``clause.root`` when ``clause`` is a
        ``RootedClause``.

        Raises
        ------
        TypeError
            If the clause contains a constraint that is not a sort, feature
            or equality constraint.
        KeyError
            If a constraint mentions a tag that is not in ``clause.tags``.
        """
        self._init_structures(clause, taxonomy)
        for c in clause.constraints:
            if not self._process_constraint(c):
                fail_tag = Tag("_FAIL")
                bottom = SortConstraint(fail_tag, self.taxonomy.bot)
                return RootedSolvedClause(fail_tag, bottom)
        # Build the normalized clause
        if isinstance(clause, RootedClause):
            return self._build_output(clause.root)
        return self._build_output()

    def _process_constraint(self, c: Constraint) -> bool:
        """Dispatch ``c`` to its handler; return ``False`` on inconsistency.

        Raises
        ------
        TypeError
            If ``c`` is of an unsupported type. Falling through here would
            return ``None``, which ``normalize`` would misread as an
            inconsistent clause.
        """
        if isinstance(c, SortConstraint):
            return self._process_sort_constraint(c.X, c.s)
        if isinstance(c, FeatureConstraint):
            return self._process_feature_constraint(c.X, c.f, c.Y)
        if isinstance(c, EqualityConstraint):
            return self._process_equality_constraint(c.X, c.Y)
        raise TypeError(f"unsupported constraint type: {type(c).__name__}")

    def _process_sort_constraint(self, X, s) -> bool:
        """Intersect the sort code of ``X``'s class with the code of ``s``."""
        rep = self.deref_tag(X)
        self.rep_to_code[rep] &= self.taxonomy.code(s)
        return self._consistency_check(rep)

    def _process_feature_constraint(self, X, f, Y) -> bool:
        """Record ``X.f = Y``, equating ``Y`` with any earlier value of ``f``."""
        X, Y = self.deref_tag(X), self.deref_tag(Y)
        feats = self.rep_to_feats[X]
        if f not in feats:
            feats[f] = Y
            self._cost[X] += 1
            return True
        # The feature already has a value: both values must denote the same
        # object.
        Z = self.deref_tag(feats[f])
        if Y != Z:
            return self._process_equality_constraint(Y, Z)
        return True

    def _process_equality_constraint(self, X, Y) -> bool:
        """Equate ``X`` and ``Y`` and propagate through shared features.

        Uses an explicit stack instead of recursion: every feature carried
        by both merged classes schedules a further equation between the two
        feature values.
        """
        stack = [(X, Y)]
        while stack:
            X, Y = stack.pop()
            merged, X, Y = self._merge_tags(X, Y)
            if not merged:
                # Tags are already merged
                continue
            # Merge sorts
            self.rep_to_code[X] &= self.rep_to_code[Y]
            if not self._consistency_check(X):
                return False
            # Merge features
            kept_feats = self.rep_to_feats[X]
            for f, Z in self.rep_to_feats[Y].items():
                Z = self.deref_tag(Z)
                if f not in kept_feats:
                    kept_feats[f] = Z
                    self._cost[X] += 1
                else:
                    stack.append((kept_feats[f], Z))
        return True

    def _consistency_check(self, tag) -> bool:
        """Return ``True`` unless the code stored for ``tag`` is ``bot_code``."""
        return self.rep_to_code[tag] != self.taxonomy.bot_code

    @overload
    def _build_output(self) -> SolvedClause: ...

    @overload
    def _build_output(self, root: Tag) -> RootedSolvedClause: ...

    def _build_output(self, root=None) -> SolvedClause:
        """Build the solved clause from the accumulated state.

        One sort constraint, plus one feature constraint per recorded
        feature, is emitted for the representative of every class. Feature
        values are dereferenced to their representatives.
        """
        # NOTE(review): ``root`` is passed through without dereferencing,
        # while every emitted constraint is on a representative. Confirm
        # that ``RootedSolvedClause`` copes with a root that was merged into
        # another tag.
        if root is None:
            clause = SolvedClause()
        else:
            clause = RootedSolvedClause(root)
        seen = set()
        for tag in self._indices:
            if (rep := self.deref_tag(tag)) in seen:
                continue
            seen.add(rep)
            for feat, val in self.rep_to_feats[rep].items():
                clause.add(FeatureConstraint(rep, feat, self.deref_tag(val)))
            sort = self.taxonomy._decode(self.rep_to_code[rep])
            clause.add(SortConstraint(rep, sort))
        return clause