// conjure_core/solver/adaptors/rustsat/adaptor.rs

1use std::any::type_name;
2use std::fmt::format;
3use std::hash::Hash;
4use std::iter::Inspect;
5use std::ops::Deref;
6use std::ptr::null;
7use std::vec;
8
9use clap::error;
10use minion_rs::ast::{Model, Tuple};
11use rustsat::encodings::am1::Def;
12use rustsat::solvers::{Solve, SolverResult};
13use rustsat::types::{Assignment, Clause, Lit, TernaryVal, Var as satVar};
14use std::collections::{BTreeMap, HashMap};
15use std::result::Result::Ok;
16use tracing_subscriber::filter::DynFilterFn;
17
18use crate::ast::Domain::Bool;
19
20use rustsat_minisat::core::Minisat;
21
22use crate::ast::{Atom, Expression, Literal, Name};
23use crate::metadata::Metadata;
24use crate::solver::adaptors::rustsat::convs::handle_cnf;
25use crate::solver::SearchComplete::NoSolutions;
26use crate::solver::{
27    self, private, SearchStatus, SolveSuccess, SolverAdaptor, SolverCallback, SolverError,
28    SolverFamily, SolverMutCallback,
29};
30use crate::stats::SolverStats;
31use crate::{ast as conjure_ast, bug, Model as ConjureModel};
32use crate::{into_matrix_expr, matrix_expr};
33
34use rustsat::instances::{BasicVarManager, Cnf, ManageVars, SatInstance};
35
36use thiserror::Error;
37
38/// A [SolverAdaptor] for interacting with the SatSolver generic and the types thereof.
39pub struct Sat {
40    __non_constructable: private::Internal,
41    model_inst: Option<SatInstance>,
42    var_map: Option<HashMap<String, Lit>>,
43    solver_inst: Minisat,
44    decision_refs: Option<Vec<String>>,
45}
46
47impl private::Sealed for Sat {}
48
49impl Default for Sat {
50    fn default() -> Self {
51        Sat {
52            __non_constructable: private::Internal,
53            solver_inst: Minisat::default(),
54            var_map: None,
55            model_inst: None,
56            decision_refs: None,
57        }
58    }
59}
60
61fn get_ref_sols(
62    find_refs: Vec<String>,
63    sol: Assignment,
64    var_map: HashMap<String, Lit>,
65) -> HashMap<Name, Literal> {
66    let mut solution: HashMap<Name, Literal> = HashMap::new();
67
68    for reference in find_refs {
69        // lit is 'Nothing' for unconstrained - if this is actually happenning, panicking is fine
70        // we are not supposed to do anything to resolve that here.
71        let lit: Lit = match var_map.get(&reference) {
72            Some(a) => *a,
73            None => panic!(
74                "There should never be a non-just literal occurring here. Something is broken upstream."
75            ),
76        };
77        solution.insert(
78            Name::User(reference),
79            match sol[lit.var()] {
80                TernaryVal::True => Literal::Int(1),
81                TernaryVal::False => Literal::Int(0),
82                TernaryVal::DontCare => Literal::Int(2),
83            },
84        );
85    }
86
87    solution
88}
89
90impl SolverAdaptor for Sat {
91    fn solve(
92        &mut self,
93        callback: SolverCallback,
94        _: private::Internal,
95    ) -> Result<SolveSuccess, SolverError> {
96        let mut solver = &mut self.solver_inst;
97
98        let cnf: (Cnf, BasicVarManager) = self.model_inst.clone().unwrap().into_cnf();
99
100        (*(solver)).add_cnf(cnf.0);
101
102        let mut has_sol = false;
103        loop {
104            let res = solver.solve().unwrap();
105
106            match res {
107                SolverResult::Sat => {}
108                SolverResult::Unsat => {
109                    return Ok(SolveSuccess {
110                        stats: SolverStats {
111                            conjure_solver_wall_time_s: -1.0,
112                            solver_family: Some(self.get_family()),
113                            solver_adaptor: Some("SAT".to_string()),
114                            nodes: None,
115                            satisfiable: None,
116                            sat_vars: None,
117                            sat_clauses: None,
118                        },
119                        status: if has_sol {
120                            SearchStatus::Complete(solver::SearchComplete::HasSolutions)
121                        } else {
122                            SearchStatus::Complete(NoSolutions)
123                        },
124                    });
125                }
126                SolverResult::Interrupted => {
127                    return Err(SolverError::Runtime("!!Interrupted Solution!!".to_string()))
128                }
129            };
130
131            let sol = solver.full_solution().unwrap();
132            has_sol = true;
133            let solution = get_ref_sols(
134                self.decision_refs.clone().unwrap(),
135                sol.clone(),
136                self.var_map.clone().unwrap(),
137            );
138
139            if !callback(solution) {
140                // println!("callback false");
141                return Ok(SolveSuccess {
142                    stats: SolverStats {
143                        conjure_solver_wall_time_s: -1.0,
144                        solver_family: Some(self.get_family()),
145                        solver_adaptor: Some("SAT".to_string()),
146                        nodes: None,
147                        satisfiable: None,
148                        sat_vars: None,
149                        sat_clauses: None,
150                    },
151                    status: SearchStatus::Incomplete(solver::SearchIncomplete::UserTerminated),
152                });
153            }
154
155            let blocking_vec: Vec<_> = sol.clone().iter().map(|lit| !lit).collect();
156            let mut blocking_cl = Clause::new();
157            for lit_i in blocking_vec {
158                blocking_cl.add(lit_i);
159            }
160
161            solver.add_clause(blocking_cl).unwrap();
162        }
163    }
164
165    fn solve_mut(
166        &mut self,
167        callback: SolverMutCallback,
168        _: private::Internal,
169    ) -> Result<SolveSuccess, SolverError> {
170        Err(SolverError::OpNotSupported("solve_mut".to_owned()))
171    }
172
173    fn load_model(&mut self, model: ConjureModel, _: private::Internal) -> Result<(), SolverError> {
174        let sym_tab = model.as_submodel().symbols().deref().clone();
175        let decisions = sym_tab.into_iter();
176
177        let mut finds: Vec<String> = Vec::new();
178
179        for find_ref in decisions {
180            if (*find_ref.1.domain().unwrap() != Bool) {
181                Err(SolverError::ModelInvalid(
182                    "Only Boolean Decision Variables supported".to_string(),
183                ))?;
184            }
185
186            let name = find_ref.0;
187            finds.push(name.to_string());
188        }
189
190        self.decision_refs = Some(finds);
191
192        let m_clone = model.clone();
193        let vec_constr = m_clone.as_submodel().constraints();
194
195        let vec_cnf = vec_constr.clone();
196
197        let mut var_map: HashMap<String, Lit> = HashMap::new();
198
199        let inst: SatInstance = handle_cnf(&vec_cnf, &mut var_map);
200
201        self.var_map = Some(var_map);
202        let cnf: (Cnf, BasicVarManager) = inst.clone().into_cnf();
203        tracing::info!("CNF: {:?}", cnf.0);
204        self.model_inst = Some(inst);
205
206        Ok(())
207    }
208
209    fn init_solver(&mut self, _: private::Internal) {}
210
211    fn get_family(&self) -> SolverFamily {
212        SolverFamily::Sat
213    }
214
215    fn get_name(&self) -> Option<String> {
216        Some("SAT".to_string())
217    }
218
219    fn add_adaptor_info_to_stats(&self, stats: SolverStats) -> SolverStats {
220        SolverStats {
221            solver_adaptor: self.get_name(),
222            solver_family: Some(self.get_family()),
223            ..stats
224        }
225    }
226
227    fn write_solver_input_file(
228        &self,
229        writer: &mut impl std::io::Write,
230    ) -> Result<(), std::io::Error> {
231        // TODO: add comments saying what conjure oxide variables each clause has
232        // e.g.
233        //      c y x z
234        //        1 2 3
235        //      c x -y
236        //        1 -1
237        // This will require handwriting a dimacs writer, but that should be easy. For now, just
238        // let rustsat write the dimacs.
239
240        let model = self.model_inst.clone().expect("model should exist when we write the solver input file, as we should be in the LoadedModel state");
241        let (cnf, var_manager): (Cnf, BasicVarManager) = model.into_cnf();
242        cnf.write_dimacs(writer, var_manager.n_used())
243    }
244}