# Source code for fosf.reasoning.theory

#!/usr/bin/env python3

from collections import defaultdict, deque
from typing import overload

from fosf.reasoning import TermUnifier
from fosf.syntax.base import Tag, Sort, Feature
from fosf.syntax.constraints import Constraint, FeatureConstraint, SortConstraint
from fosf.syntax.taxonomy import SortTaxonomy, FuzzySortTaxonomy
from fosf.syntax.terms import Term, NormalTerm
from fosf.syntax.theory import OsfTheory


class _Frame:
    """Bookkeeping record attached to one tag.

    Holds the tag itself, a two-way association between "global" tags and
    "local" tags, and a ``main_sort`` attribute that is only declared in
    ``__init__`` and must be assigned by the caller before the frame is
    printed.
    """

    def __init__(self, tag: Tag):
        self.tag = tag
        # Global X -> set of local Ys
        self.local_tags = defaultdict(set)
        # Local Y -> global X
        self.global_tag = {}
        # Declared here, assigned by whoever creates the frame
        self.main_sort: Sort

    def pairs(self):
        """Yield every ``(global, local)`` pair recorded in ``local_tags``."""
        for global_tag, local_group in self.local_tags.items():
            yield from ((global_tag, local_tag) for local_tag in local_group)

    def __str__(self):
        """Render as ``tag:main_sort [g1/l1, g2/l2, ...]``."""
        rendered = [f"{first}/{second}" for first, second in self.pairs()]
        return "{}:{} [{}]".format(self.tag, self.main_sort, ", ".join(rendered))

    def __repr__(self):
        """Same text as ``__str__``."""
        return self.__str__()



class TheoryTermNormalizer:
    """Normalize a term with respect to an OSF theory.

    The normalizer keeps a union-find structure over tags, a feature map on
    the representatives, and one ``_Frame`` per representative that has
    received a sort constraint.  Work is driven by three containers that are
    drained in priority order by :meth:`normalize`: ``eq_queue`` (tag
    equalities), ``queue`` (sort / feature constraints) and ``rule_9_stack``
    (pending feature introductions).

    All state is (re)created by ``_init_structures`` on every call to
    :meth:`normalize`; ``__init__`` only declares the attributes.
    """

    def __init__(self):
        self.taxonomy: SortTaxonomy
        # X.f = Y -> self.rep_to_feats[X][f] = Y
        self.rep_to_feats: dict[Tag, dict[Feature, Tag]]
        # For union-find
        self._parents: dict
        self._indices: dict
        self._cost: dict
        # Theory for normalization
        self.theory: OsfTheory
        # Map a tag to the frames where it appears
        self.global_to_frames: dict[Tag, set[Tag]]
        # Map a tag to their main frame
        self.frames: dict[Tag, _Frame]
        self.eq_queue: deque[tuple[Tag, Tag]]
        self.queue: deque[Constraint]
        self.rule_9_stack: dict[Tag, Feature]
        # Counter used by _new_tag to mint fresh "Z<n>" tags
        self.tag_counter: int

    @overload
    def normalize(self, term: Term, theory: OsfTheory, normalize_first: bool,
                  return_degree=False) -> NormalTerm: ...

    @overload
    def normalize(self, term: Term, theory: OsfTheory, normalize_first: bool,
                  return_degree=True) -> tuple[NormalTerm, float]: ...

    # Main method
    def normalize(self, term: Term, theory: OsfTheory, normalize_first=True,
                  return_degree=False) -> NormalTerm | tuple[NormalTerm, float]:
        """Normalize ``term`` against ``theory``.

        Parameters
        ----------
        term
            The term to normalize.
        theory
            The theory providing the taxonomy and the sort definitions.
        normalize_first
            When true, ``term`` is first normalized with a ``TermUnifier``
            against ``theory.taxonomy`` and the unifier's union-find state is
            reused.
        return_degree
            When true, the result is a ``(term, degree)`` pair.  The degree
            is computed by ``_subsumption_degree`` for a
            ``FuzzySortTaxonomy`` and is ``1.0`` otherwise.

        Returns
        -------
        The normalized term, optionally paired with a degree.  If any
        constraint turns out inconsistent, a term tagged ``_FAIL`` with the
        taxonomy's bottom sort is returned.
        """
        cn = None
        if normalize_first:
            cn = TermUnifier()
            term = cn.normalize(term, theory.taxonomy)
        self._init_structures(term, theory, cn)
        root = term.X
        self.queue.extend(term.generate_constraints())

        # Equalities first, then ordinary constraints, then pending
        # rule-9 feature introductions.
        while self.eq_queue or self.queue or self.rule_9_stack:
            if self.eq_queue:
                X, Y = self.eq_queue.popleft()
                consistent = self._process_equality_constraint(X, Y)
            elif self.queue:
                consistent = self._process_constraint(self.queue.popleft())
            else:
                X, f = self.rule_9_stack.popitem()
                Z = self._new_tag()
                consistent = self._process_feature_constraint(X, f, Z)
            if not consistent:
                # NOTE(review): this returns a bare NormalTerm even when
                # return_degree is true -- confirm callers expect that.
                FAIL_TAG = Tag("_FAIL")
                return NormalTerm(FAIL_TAG, self.taxonomy.bot)

        output = self._build_output(root)
        if return_degree:
            if isinstance(self.taxonomy, FuzzySortTaxonomy):
                return output, self._subsumption_degree()
            return output, 1.0
        return output

    def _init_structures(self, clause, theory, cn=None):
        """Reset every working structure for a new normalization run.

        If ``cn`` is given (and truthy), its ``_parents`` / ``_indices``
        dictionaries are copied as the initial union-find state.
        """
        self.taxonomy = theory.taxonomy
        self.rep_to_feats = defaultdict(dict)
        # For union-find
        if cn:
            self._parents = cn._parents.copy()  # TODO keep copy?
            self._indices = cn._indices.copy()
        else:
            self._parents = {}
            self._indices = {}
        # _add_tag ignores tags that are already registered
        for X in clause.tags():
            self._add_tag(X)
        self._cost = defaultdict(int)
        self.theory = theory
        self.global_to_frames = defaultdict(set)
        self.frames = dict()
        self.tag_counter = 0
        self.eq_queue = deque()  # High priority constraints
        self.queue = deque()
        self.rule_9_stack = dict()

    # Union-find methods
    def _add_tag(self, X: Tag):
        """Register ``X`` as its own representative; no-op if already known."""
        if X in self._indices:
            return
        self._parents[X] = X
        self._indices[X] = len(self._indices)

    def deref_tag(self, X: Tag) -> Tag:
        """Return the representative of ``X``.

        Parent links visited on the way are re-pointed to their grandparent.

        Raises
        ------
        KeyError
            If ``X`` was never registered.
        """
        if X not in self._indices:
            raise KeyError(X)
        parents = self._parents
        while self._indices[X] != self._indices[parents[X]]:
            parents[X] = parents[parents[X]]
            X = parents[X]
        return X

    def _merge_tags(self, X: Tag, Y: Tag) -> tuple[bool, Tag, Tag]:
        """Union the classes of ``X`` and ``Y``.

        Returns ``(merged, kept, absorbed)``.  ``merged`` is false when both
        tags already share a representative, in which case both returned
        tags are that representative.  The kept representative is the one
        that compares greater on (has a frame, cost, other tag's index).
        """
        X, Y = self.deref_tag(X), self.deref_tag(Y)
        if self._indices[X] == self._indices[Y]:
            return False, X, Y
        # Merge Y into X or X into Y depending on cost
        fX = X in self.frames
        fY = Y in self.frames
        if (fX, self._cost[X], self._indices[Y]) < (fY, self._cost[Y], self._indices[X]):
            X, Y = Y, X
        self._parents[Y] = X
        return True, X, Y

    def _connected(self, X: Tag, Y: Tag) -> bool:
        """Tell whether ``X`` and ``Y`` have the same representative."""
        return self._indices[self.deref_tag(X)] == self._indices[self.deref_tag(Y)]

    # Methods for processing constraints
    def _process_constraint(self, c: Constraint):
        """Dispatch a sort or feature constraint; return its consistency.

        Constraints of any other type are treated as consistent.
        """
        consistent = True
        if isinstance(c, SortConstraint):
            consistent = self._process_sort_constraint(c.X, c.s)
        if isinstance(c, FeatureConstraint):
            consistent = self._process_feature_constraint(c.X, c.f, c.Y)
        return consistent

    def _process_sort_constraint(self, X, s):
        """Apply ``X : s``; return ``False`` if the sort collapses to bottom.

        If ``X``'s representative already has a frame, its main sort is
        refined to the GLB with ``s``.  Otherwise a new frame is created
        whose main sort is ``s``.
        """
        X = self.deref_tag(X)
        if X in self.frames:
            # Refine frame
            current_s = self.frames[X].main_sort
            if self.taxonomy.is_subsort(current_s, s):
                return True
            glb = self.taxonomy.glb(s, current_s)
            if glb == self.taxonomy.bot:
                return False
            self.frames[X].main_sort = glb
            Y_glb = self.theory.definitions[glb].X
            self._update_frame_locals(X, X, Y_glb)
            self._check_features(X, Y_glb)
            return True
        # Otherwise, initialize frame
        Y = self.theory.definitions[s].X
        self.frames[X] = _Frame(X)
        self.frames[X].main_sort = s
        self.frames[X].local_tags[X].add(Y)
        self.frames[X].global_tag[Y] = X
        self.global_to_frames[X].add(X)
        self._cost[X] += 1
        self._check_features(X, Y)
        # TODO Necessary here?
        self._check_rule_9(X, Y, X)
        return True

    def _process_feature_constraint(self, X1, f, X2):
        """Apply ``X1.f = X2``; always returns ``True``.

        If ``X1``'s representative already has feature ``f`` pointing to a
        different representative, an equality between the two targets is
        queued instead of recording anything.
        """
        X1, X2 = self.deref_tag(X1), self.deref_tag(X2)
        if f in self.rep_to_feats[X1]:
            X3 = self.deref_tag(self.rep_to_feats[X1][f])
            if X3 != X2:
                self.eq_queue.append((X2, X3))
            return True
        self.rep_to_feats[X1][f] = X2
        self._cost[X1] += 1
        # Apply rule 5
        for X in self.global_to_frames[X1]:
            updates = set()
            for Y1 in self.frames[X].local_tags[X1]:
                if f in self.theory.features(Y1):
                    Y2 = self.theory.features(Y1)[f]
                    sort = self.theory.sort(Y2)
                    if sort is not None:
                        self.queue.append(SortConstraint(X2, sort))
                    if X1 == X2:
                        # self._update_frame_locals(X, X2, Y2) would
                        # modify self.frames[X].local_tags[X1]
                        # resulting in a RuntimeError
                        # We delay calling the method in this case
                        updates.add((X, X2, Y2))
                    else:
                        self._update_frame_locals(X, X2, Y2)
            for args in updates:
                self._update_frame_locals(*args)
        return True

    def _process_equality_constraint(self, X1, X2):
        """Apply ``X1 = X2``; return ``False`` if the merged sort is bottom.

        After the union, features of the absorbed tag are moved onto the
        kept one (queueing equalities for clashing targets), the two frames
        are combined when both exist, and every other frame that mentioned
        the absorbed tag is rewritten to mention the kept one.
        """
        merged, X1, X2 = self._merge_tags(X1, X2)
        if not merged:
            return True
        # Merge features
        for f, Z in self.rep_to_feats[X2].items():
            Z = self.deref_tag(Z)
            if f in self.rep_to_feats[X1]:
                Z1 = self.deref_tag(self.rep_to_feats[X1][f])
                if Z1 != Z:
                    self.eq_queue.append((Z1, Z))
            else:
                self.rep_to_feats[X1][f] = Z
        # Merge frames (either both exist, or only X1's exists, or neither)
        if X1 in self.frames and X2 in self.frames:
            frame1 = self.frames[X1]
            frame2 = self.frames[X2]
            ss1 = frame1.main_sort
            ss2 = frame2.main_sort
            s_glb = self.theory.taxonomy.glb(ss1, ss2)
            frame1.main_sort = s_glb
            if s_glb == self.taxonomy.bot:
                return False
            Yglb = self.theory.definitions[s_glb].X
            for X, Y in frame2.pairs():
                if X == X2 or Y == X2:
                    continue
                self.global_to_frames[X].discard(X2)
                self._update_frame_locals(X1, X, Y)
            self._update_frame_locals(X1, X1, Yglb)
            self._check_features(X1, Yglb)
        for X in self.global_to_frames[X2] - {X1, X2}:
            frame = self.frames[X]
            Y1s = frame.local_tags[X2]
            for Y1 in Y1s:
                self._update_frame_locals(X, X1, Y1)
            frame.local_tags.pop(X2)
        return True

    # Utilities and data structure maintenance
    def _new_tag(self):
        """Create, register and return the first unused tag named ``Z<n>``."""
        while (tag := Tag(f"Z{self.tag_counter}")) in self._indices:
            self.tag_counter += 1
        self._add_tag(tag)
        return tag

    def _check_common_features(self, X: Tag, Y1: Tag, Y2: Tag):
        """Schedule on ``rule_9_stack`` a feature shared by ``Y1`` and ``Y2``.

        NOTE(review): the stack is keyed by ``X``, so when several features
        are shared only the last one iterated is retained -- confirm this is
        intended.
        """
        Y1_features = self.theory.features(Y1)
        Y2_features = self.theory.features(Y2)
        for f in Y1_features.keys() & Y2_features.keys():
            self.rule_9_stack[X] = f

    def _check_features(self, X: Tag, Y: Tag):
        """Propagate the theory features of ``Y`` along the features of ``X``.

        Starting from ``(X, Y)``, every feature that both the global tag and
        the local tag carry yields a new ``(global, local)`` pair, which is
        recorded in the frame of the *initial* ``X`` and explored in turn.
        A sort constraint is queued whenever the theory gives the local
        target a sort.
        """
        # TODO: check efficiency: is it always necessary to recurse?
        frame_X = X
        stack = {(X, Y)}
        seen = set()
        while stack:
            X, Y = stack.pop()
            if (X, Y) in seen:
                continue
            seen.add((X, Y))
            for f, X1 in self.rep_to_feats[X].items():
                X1 = self.deref_tag(X1)
                if f in self.theory.features(Y):
                    Y1 = self.theory.features(Y)[f]
                    self._update_frame_locals(frame_X, X1, Y1)
                    stack.add((X1, Y1))
                    s1 = self.theory.sort(Y1)
                    if s1 is not None:
                        c = SortConstraint(X1, s1)
                        self.queue.append(c)

    def _update_frame_locals(self, X, X1, Y1):
        """Record in ``X``'s frame that global ``X1`` matches local ``Y1``.

        Local tags already stored for ``X1`` are compared with ``Y1`` via
        ``theory.tag_taxonomy``: less specific ones are dropped, and ``Y1``
        is not added if a more specific one is present.  If ``Y1`` is
        already bound to a different global tag in this frame, an equality
        between the two global tags is queued.
        """
        # Check local tags associated with X1 in X's frame
        frame = self.frames[X]
        if X1 not in frame.local_tags:
            frame.local_tags[X1].add(Y1)
        else:
            current_Y1s = frame.local_tags[X1]
            if Y1 in current_Y1s:
                return
            add = False
            to_remove = set()
            for current_Y1 in current_Y1s:
                if self.theory.tag_taxonomy.is_subsort(current_Y1, Y1):
                    # We already have a more specific version of Y1 in this frame
                    # No need to check anything else
                    add = False
                    break
                add = True
                if self.theory.tag_taxonomy.is_subsort(Y1, current_Y1):
                    to_remove.add(current_Y1)
                else:
                    # rule 9
                    self._check_common_features(X1, current_Y1, Y1)
            frame.local_tags[X1].difference_update(to_remove)
            if add:
                frame.local_tags[X1].add(Y1)
        self._check_rule_9(X1, Y1, tag_to_exclude=X)
        # Check the global tag for Y1 in X
        if Y1 not in frame.global_tag:
            frame.global_tag[Y1] = X1
        else:
            current_X1 = self.deref_tag(frame.global_tag[Y1])
            if X1 != current_X1:
                self.eq_queue.append((X1, current_X1))
        self.global_to_frames[X1].add(X)

    def _check_rule_9(self, X: Tag, Y: Tag, tag_to_exclude=None):
        """Compare ``Y`` with the locals of ``X`` in every other frame.

        For each frame mentioning ``X`` (except ``tag_to_exclude``), local
        tags unrelated to ``Y`` in ``theory.tag_taxonomy`` are passed to
        ``_check_common_features``.  Returns early as soon as a frame
        already holds ``Y`` for ``X``.
        """
        done = set()
        for tag in self.global_to_frames[X] - {tag_to_exclude}:
            current_locals = self.frames[tag].local_tags[X]
            if Y in current_locals:
                return
            for Yp in current_locals:
                if Yp in done:
                    continue
                done.add(Yp)
                if self.theory.tag_taxonomy.is_subsort(Yp, Y):
                    # We already have a more specific version of Y in this frame
                    # No need to check anything else
                    break
                if self.theory.tag_taxonomy.is_subsort(Y, Yp):
                    # Our tag is more specific than Yp, but might have more features
                    # We continue checking the others
                    continue
                self._check_common_features(X, Y, Yp)

    def _build_output(self, root: Tag) -> NormalTerm:
        """Build the output term reachable from ``root``.

        Each representative is expanded once; later occurrences become a
        bare ``NormalTerm(rep)``.  Representatives without a frame get the
        taxonomy's top sort.
        """
        seen = set()

        def visit(tag):
            rep = self.deref_tag(tag)
            if rep in seen:
                return NormalTerm(rep)
            seen.add(rep)
            if rep in self.frames:
                sort = self.frames[rep].main_sort
            else:
                sort = self.taxonomy.top
            subterms = {f: visit(other)
                        for f, other in self.rep_to_feats[rep].items()}
            return NormalTerm(rep, sort, subterms)

        return visit(root)

    def _subsumption_pairs(self):
        """Collect ``main sort -> {theory sorts}`` pairs that differ.

        For every representative tag that owns a frame, each
        ``(global, local)`` pair of the frame contributes the main sort of
        the global tag's frame and the theory sort of the local tag, unless
        they are equal.  Tags and globals without a frame are skipped.
        """
        derefed_tags = {self.deref_tag(tag) for tag in self._indices}
        sub_pairs = defaultdict(set)
        for tag in derefed_tags:
            # A representative that never received a sort constraint has no
            # frame (see _build_output); indexing would raise KeyError.
            frame = self.frames.get(tag)
            if frame is None:
                continue
            for X, Y in frame.pairs():
                X = self.deref_tag(X)
                if X not in self.frames:
                    continue
                s = self.frames[X].main_sort
                t = self.theory.sort(Y)
                if s != t:
                    sub_pairs[s].add(t)
        return sub_pairs

    def _subsumption_degree(self) -> float:
        """Return the minimum taxonomy degree over all subsumption pairs.

        Returns ``1.0`` when there are no pairs.
        """
        sub_pairs = self._subsumption_pairs()
        if not sub_pairs:
            # if sub_pairs is empty, s and t in _subsumption_pairs were always equal
            return 1.0
        sources = sub_pairs.keys()
        targets = {val for values in sub_pairs.values() for val in values}
        degrees = self.theory.taxonomy.degree(sources, targets)
        alpha = min(degrees[s][t] for s, ts in sub_pairs.items() for t in ts)
        return alpha