"""Implementation of DPLL algorithm Further improvements: eliminate calls to pl_true, implement branching rules, efficient unit propagation. References: - https://en.wikipedia.org/wiki/DPLL_algorithm - https://www.researchgate.net/publication/242384772_Implementations_of_the_DPLL_Algorithm """ from sympy.core.sorting import default_sort_key from sympy.logic.boolalg import Or, Not, conjuncts, disjuncts, to_cnf, \ to_int_repr, _find_predicates from sympy.assumptions.cnf import CNF from sympy.logic.inference import pl_true, literal_symbol def dpll_satisfiable(expr): """ Check satisfiability of a propositional sentence. It returns a model rather than True when it succeeds >>> from sympy.abc import A, B >>> from sympy.logic.algorithms.dpll import dpll_satisfiable >>> dpll_satisfiable(A & ~B) {A: True, B: False} >>> dpll_satisfiable(A & ~A) False """ if not isinstance(expr, CNF): clauses = conjuncts(to_cnf(expr)) else: clauses = expr.clauses if False in clauses: return False symbols = sorted(_find_predicates(expr), key=default_sort_key) symbols_int_repr = set(range(1, len(symbols) + 1)) clauses_int_repr = to_int_repr(clauses, symbols) result = dpll_int_repr(clauses_int_repr, symbols_int_repr, {}) if not result: return result output = {} for key in result: output.update({symbols[key - 1]: result[key]}) return output def dpll(clauses, symbols, model): """ Compute satisfiability in a partial model. Clauses is an array of conjuncts. >>> from sympy.abc import A, B, D >>> from sympy.logic.algorithms.dpll import dpll >>> dpll([A, B, D], [A, B], {D: False}) False """ # compute DP kernel P, value = find_unit_clause(clauses, model) while P: model.update({P: value}) symbols.remove(P) if not value: P = ~P clauses = unit_propagate(clauses, P) P, value = find_unit_clause(clauses, model) P, value = find_pure_symbol(symbols, clauses) while P: model.update({P: value}) symbols.remove(P) if not value: P = ~P clauses = unit_propagate(clauses, P) P, value = find_pure_symbol(symbols, clauses) # end DP kernel unknown_clauses = [] for c in clauses: val = pl_true(c, model) if val is False: return False if val is not True: unknown_clauses.append(c) if not unknown_clauses: return model if not clauses: return model P = symbols.pop() model_copy = model.copy() model.update({P: True}) model_copy.update({P: False}) symbols_copy = symbols[:] return (dpll(unit_propagate(unknown_clauses, P), symbols, model) or dpll(unit_propagate(unknown_clauses, Not(P)), symbols_copy, model_copy)) def dpll_int_repr(clauses, symbols, model): """ Compute satisfiability in a partial model. Arguments are expected to be in integer representation >>> from sympy.logic.algorithms.dpll import dpll_int_repr >>> dpll_int_repr([{1}, {2}, {3}], {1, 2}, {3: False}) False """ # compute DP kernel P, value = find_unit_clause_int_repr(clauses, model) while P: model.update({P: value}) symbols.remove(P) if not value: P = -P clauses = unit_propagate_int_repr(clauses, P) P, value = find_unit_clause_int_repr(clauses, model) P, value = find_pure_symbol_int_repr(symbols, clauses) while P: model.update({P: value}) symbols.remove(P) if not value: P = -P clauses = unit_propagate_int_repr(clauses, P) P, value = find_pure_symbol_int_repr(symbols, clauses) # end DP kernel unknown_clauses = [] for c in clauses: val = pl_true_int_repr(c, model) if val is False: return False if val is not True: unknown_clauses.append(c) if not unknown_clauses: return model P = symbols.pop() model_copy = model.copy() model.update({P: True}) model_copy.update({P: False}) symbols_copy = symbols.copy() return (dpll_int_repr(unit_propagate_int_repr(unknown_clauses, P), symbols, model) or dpll_int_repr(unit_propagate_int_repr(unknown_clauses, -P), symbols_copy, model_copy)) ### helper methods for DPLL def pl_true_int_repr(clause, model={}): """ Lightweight version of pl_true. Argument clause represents the set of args of an Or clause. This is used inside dpll_int_repr, it is not meant to be used directly. >>> from sympy.logic.algorithms.dpll import pl_true_int_repr >>> pl_true_int_repr({1, 2}, {1: False}) >>> pl_true_int_repr({1, 2}, {1: False, 2: False}) False """ result = False for lit in clause: if lit < 0: p = model.get(-lit) if p is not None: p = not p else: p = model.get(lit) if p is True: return True elif p is None: result = None return result def unit_propagate(clauses, symbol): """ Returns an equivalent set of clauses If a set of clauses contains the unit clause l, the other clauses are simplified by the application of the two following rules: 1. every clause containing l is removed 2. in every clause that contains ~l this literal is deleted Arguments are expected to be in CNF. >>> from sympy.abc import A, B, D >>> from sympy.logic.algorithms.dpll import unit_propagate >>> unit_propagate([A | B, D | ~B, B], B) [D, B] """ output = [] for c in clauses: if c.func != Or: output.append(c) continue for arg in c.args: if arg == ~symbol: output.append(Or(*[x for x in c.args if x != ~symbol])) break if arg == symbol: break else: output.append(c) return output def unit_propagate_int_repr(clauses, s): """ Same as unit_propagate, but arguments are expected to be in integer representation >>> from sympy.logic.algorithms.dpll import unit_propagate_int_repr >>> unit_propagate_int_repr([{1, 2}, {3, -2}, {2}], 2) [{3}] """ negated = {-s} return [clause - negated for clause in clauses if s not in clause] def find_pure_symbol(symbols, unknown_clauses): """ Find a symbol and its value if it appears only as a positive literal (or only as a negative) in clauses. >>> from sympy.abc import A, B, D >>> from sympy.logic.algorithms.dpll import find_pure_symbol >>> find_pure_symbol([A, B, D], [A|~B,~B|~D,D|A]) (A, True) """ for sym in symbols: found_pos, found_neg = False, False for c in unknown_clauses: if not found_pos and sym in disjuncts(c): found_pos = True if not found_neg and Not(sym) in disjuncts(c): found_neg = True if found_pos != found_neg: return sym, found_pos return None, None def find_pure_symbol_int_repr(symbols, unknown_clauses): """ Same as find_pure_symbol, but arguments are expected to be in integer representation >>> from sympy.logic.algorithms.dpll import find_pure_symbol_int_repr >>> find_pure_symbol_int_repr({1,2,3}, ... [{1, -2}, {-2, -3}, {3, 1}]) (1, True) """ all_symbols = set().union(*unknown_clauses) found_pos = all_symbols.intersection(symbols) found_neg = all_symbols.intersection([-s for s in symbols]) for p in found_pos: if -p not in found_neg: return p, True for p in found_neg: if -p not in found_pos: return -p, False return None, None def find_unit_clause(clauses, model): """ A unit clause has only 1 variable that is not bound in the model. >>> from sympy.abc import A, B, D >>> from sympy.logic.algorithms.dpll import find_unit_clause >>> find_unit_clause([A | B | D, B | ~D, A | ~B], {A:True}) (B, False) """ for clause in clauses: num_not_in_model = 0 for literal in disjuncts(clause): sym = literal_symbol(literal) if sym not in model: num_not_in_model += 1 P, value = sym, not isinstance(literal, Not) if num_not_in_model == 1: return P, value return None, None def find_unit_clause_int_repr(clauses, model): """ Same as find_unit_clause, but arguments are expected to be in integer representation. >>> from sympy.logic.algorithms.dpll import find_unit_clause_int_repr >>> find_unit_clause_int_repr([{1, 2, 3}, ... {2, -3}, {1, -2}], {1: True}) (2, False) """ bound = set(model) | {-sym for sym in model} for clause in clauses: unbound = clause - bound if len(unbound) == 1: p = unbound.pop() if p < 0: return -p, False else: return p, True return None, None