// conjure_core/ast/comprehension.rs

1use std::{
2    cell::RefCell,
3    collections::{HashMap, HashSet},
4    fmt::Display,
5    rc::Rc,
6    sync::{Arc, Mutex, RwLock},
7};
8
9use itertools::Itertools as _;
10use serde::{Deserialize, Serialize};
11use uniplate::{derive::Uniplate, Biplate};
12
13use crate::{
14    ast::{Atom, DeclarationKind},
15    bug,
16    context::Context,
17    into_matrix_expr, matrix_expr,
18    metadata::Metadata,
19    solver::{Solver, SolverError},
20};
21
22use super::{Declaration, Domain, Expression, Model, Name, SubModel, SymbolTable};
23
/// Selects how guards that reference non-induction names are folded into the return expression
/// when a comprehension is built (see `ComprehensionBuilder::with_return_value`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ComprehensionKind {
    /// Such guards are not yet supported for sums; building one panics.
    Sum,
    /// A guard `g` and return expression `e` become the implication `g -> e`.
    And,
    /// A guard `g` is conjoined with the return expression `e`.
    Or,
}
/// A comprehension.
///
/// Stored as two sub-models whose symbol tables are children of the enclosing scope: one that
/// enumerates the values of the induction variables, and one that holds the expression
/// produced for each such assignment.
#[derive(Clone, PartialEq, Eq, Uniplate, Serialize, Deserialize, Debug)]
#[uniplate(walk_into=[SubModel])]
#[biplate(to=SubModel)]
#[biplate(to=Expression,walk_into=[SubModel])]
pub struct Comprehension {
    /// Sub-model holding the return expression as its constraint(s); the induction variables
    /// are declared in its symbol table as givens, so it can be rewritten like any other model.
    return_expression_submodel: SubModel,
    /// Sub-model declaring the induction variables as decision variables, with the guards that
    /// only reference induction variables as its constraints.
    generator_submodel: SubModel,
    /// The names of the induction variables (one per generator).
    induction_vars: Vec<Name>,
}
36    generator_submodel: SubModel,
37    induction_vars: Vec<Name>,
38}
39
40impl Comprehension {
41    pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
42        self.return_expression_submodel
43            .clone()
44            .into_single_expression()
45            .domain_of(syms)
46    }
47
48    /// Solves this comprehension using Minion, returning the resulting expressions.
49    ///
50    /// If successful, this modifies the symbol table given to add aux-variables needed inside the
51    /// expanded expressions.
52    pub fn solve_with_minion(
53        self,
54        symtab: &mut SymbolTable,
55    ) -> Result<Vec<Expression>, SolverError> {
56        let minion = Solver::new(crate::solver::adaptors::Minion::new());
57        // FIXME: weave proper context through
58        let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
59
60        // only branch on the induction variables.
61        model.search_order = Some(self.induction_vars.clone());
62
63        *model.as_submodel_mut() = self.generator_submodel.clone();
64
65        let minion = minion.load_model(model.clone())?;
66
67        let values = Arc::new(Mutex::new(Vec::new()));
68        let values_ptr = Arc::clone(&values);
69
70        tracing::debug!(model=%model.clone(),comprehension=%self.clone(),"Minion solving comprehension");
71        minion.solve(Box::new(move |sols| {
72            // TODO: deal with represented names if induction variables are abslits.
73            let values = &mut *values_ptr.lock().unwrap();
74            values.push(sols);
75            true
76        }))?;
77
78        let values = values.lock().unwrap().clone();
79
80        let mut return_expressions = vec![];
81
82        for value in values {
83            // convert back to an expression
84
85            let return_expression_submodel = self.return_expression_submodel.clone();
86            let child_symtab = return_expression_submodel.symbols().clone();
87            let return_expression = return_expression_submodel.into_single_expression();
88
89            // we only want to substitute induction variables.
90            // (definitely not machine names, as they mean something different in this scope!)
91            let value: HashMap<_, _> = value
92                .into_iter()
93                .filter(|(n, _)| self.induction_vars.contains(n))
94                .collect();
95
96            let value_ptr = Arc::new(value);
97            let value_ptr_2 = Arc::clone(&value_ptr);
98
99            // substitute in the values for the induction variables
100            let return_expression = return_expression.transform_bi(Arc::new(move |x: Atom| {
101                let Atom::Reference(ref name) = x else {
102                    return x;
103                };
104
105                // is this referencing an induction var?
106                let Some(lit) = value_ptr_2.get(name) else {
107                    return x;
108                };
109
110                Atom::Literal(lit.clone())
111            }));
112
113            // merge symbol table into parent scope
114
115            // convert machine names in child_symtab to ones that we know are unused in the parent
116            // symtab
117            let mut machine_name_translations: HashMap<Name, Name> = HashMap::new();
118
119            // populate machine_name_translations, and move the declarations from child to parent
120            for (name, decl) in child_symtab.into_iter_local() {
121                // skip givens for induction vars§
122                if value_ptr.get(&name).is_some()
123                    && matches!(decl.kind(), DeclarationKind::Given(_))
124                {
125                    continue;
126                }
127
128                let Name::Machine(_) = &name else {
129                    bug!("the symbol table of the return expression of a comprehension should only contain machine names");
130                };
131
132                let new_machine_name = symtab.gensym();
133
134                let new_decl = (*decl).clone().with_new_name(new_machine_name.clone());
135                symtab.insert(Rc::new(new_decl)).unwrap();
136
137                machine_name_translations.insert(name, new_machine_name);
138            }
139
140            // rename references to aux vars in the return_expression
141            let return_expression =
142                return_expression.transform_bi(Arc::new(
143                    move |name| match machine_name_translations.get(&name) {
144                        Some(new_name) => new_name.clone(),
145                        None => name,
146                    },
147                ));
148
149            return_expressions.push(return_expression);
150        }
151
152        Ok(return_expressions)
153    }
154
155    pub fn return_expression(self) -> Expression {
156        self.return_expression_submodel.into_single_expression()
157    }
158
159    pub fn replace_return_expression(&mut self, new_expr: Expression) {
160        let new_expr = match new_expr {
161            Expression::And(_, exprs) if exprs.clone().unwrap_list().is_some() => {
162                Expression::Root(Metadata::new(), exprs.unwrap_list().unwrap())
163            }
164            expr => Expression::Root(Metadata::new(), vec![expr]),
165        };
166
167        *self.return_expression_submodel.root_mut_unchecked() = new_expr;
168    }
169
170    /// Adds a guard to the comprehension. Returns false if the guard does not only reference induction variables.
171    pub fn add_induction_guard(&mut self, guard: Expression) -> bool {
172        if self.is_induction_guard(&guard) {
173            self.generator_submodel.add_constraint(guard);
174            true
175        } else {
176            false
177        }
178    }
179
180    /// True iff expr only references induction variables.
181    pub fn is_induction_guard(&self, expr: &Expression) -> bool {
182        is_induction_guard(&(self.induction_vars.clone().into_iter().collect()), expr)
183    }
184}
185
186impl Display for Comprehension {
187    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188        let generators: String = self
189            .generator_submodel
190            .symbols()
191            .clone()
192            .into_iter_local()
193            .map(|(name, decl)| (name, decl.domain().unwrap().clone()))
194            .map(|(name, domain)| format!("{name}: {domain}"))
195            .join(",");
196
197        let guards = self
198            .generator_submodel
199            .constraints()
200            .iter()
201            .map(|x| format!("{x}"))
202            .join(",");
203
204        let generators_and_guards = itertools::join([generators, guards], ",");
205
206        let expression = &self.return_expression_submodel;
207        write!(f, "[{expression} | {generators_and_guards}]")
208    }
209}
210
/// A builder for a comprehension.
///
/// Collects generators and guards; `with_return_value` turns them into a [`Comprehension`].
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ComprehensionBuilder {
    /// Guards added so far, in insertion order.
    guards: Vec<Expression>,
    /// Generators (`name: domain`) added so far, in insertion order.
    generators: Vec<(Name, Domain)>,
    /// The names of all generators; used to reject duplicate generators and to decide which
    /// guards only reference induction variables.
    induction_variables: HashSet<Name>,
}
218
219impl ComprehensionBuilder {
220    pub fn new() -> Self {
221        Default::default()
222    }
223    pub fn guard(mut self, guard: Expression) -> Self {
224        self.guards.push(guard);
225        self
226    }
227
228    pub fn generator(mut self, name: Name, domain: Domain) -> Self {
229        assert!(!self.induction_variables.contains(&name));
230        self.induction_variables.insert(name.clone());
231        self.generators.push((name, domain));
232        self
233    }
234
235    /// Creates a comprehension with the given return expression.
236    ///
237    /// If a comprehension kind is not given, comprehension guards containing decision variables
238    /// are invalid, and will cause a panic.
239    pub fn with_return_value(
240        self,
241        mut expression: Expression,
242        parent: Rc<RefCell<SymbolTable>>,
243        comprehension_kind: Option<ComprehensionKind>,
244    ) -> Comprehension {
245        let mut generator_submodel = SubModel::new(parent.clone());
246
247        // TODO:also allow guards that reference lettings and givens.
248
249        let induction_variables = self.induction_variables;
250
251        // only guards referencing induction variables can go inside the comprehension
252        let (induction_guards, other_guards): (Vec<_>, Vec<_>) = self
253            .guards
254            .into_iter()
255            .partition(|x| is_induction_guard(&induction_variables, x));
256
257        // handle guards that reference non-induction variables
258        if !other_guards.is_empty() {
259            let comprehension_kind = comprehension_kind.expect(
260                "if any guards reference decision variables, a comprehension kind should be given",
261            );
262
263            let guard_expr = match other_guards.as_slice() {
264                [x] => x.clone(),
265                xs => Expression::And(Metadata::new(), Box::new(into_matrix_expr!(xs.to_vec()))),
266            };
267
268            expression = match comprehension_kind {
269                ComprehensionKind::And => {
270                    Expression::Imply(Metadata::new(), Box::new(guard_expr), Box::new(expression))
271                }
272                ComprehensionKind::Or => Expression::And(
273                    Metadata::new(),
274                    Box::new(Expression::And(
275                        Metadata::new(),
276                        Box::new(matrix_expr![guard_expr, expression]),
277                    )),
278                ),
279
280                ComprehensionKind::Sum => {
281                    panic!("guards that reference decision variables not yet implemented for sum");
282                }
283            }
284        }
285
286        generator_submodel.add_constraints(induction_guards);
287        for (name, domain) in self.generators.clone() {
288            generator_submodel
289                .symbols_mut()
290                .insert(Rc::new(Declaration::new_var(name, domain)));
291        }
292
293        // The return_expression is a sub-model of `parent` containing the return_expression and
294        // the induction variables as givens. This allows us to rewrite it as per usual without
295        // doing weird things to the induction vars.
296        //
297        // All the machine name declarations created by flattening the return expression will be
298        // kept inside the scope, allowing us to duplicate them during unrolling (we need a copy of
299        // each aux var for each set of assignments of induction variables).
300
301        let mut return_expression_submodel = SubModel::new(parent);
302        for (name, domain) in self.generators {
303            return_expression_submodel
304                .symbols_mut()
305                .insert(Rc::new(Declaration::new_given(name, domain)))
306                .unwrap();
307        }
308
309        return_expression_submodel.add_constraint(expression);
310
311        Comprehension {
312            return_expression_submodel,
313            generator_submodel,
314            induction_vars: induction_variables.into_iter().collect_vec(),
315        }
316    }
317}
318
319/// True iff the guard only references induction variables.
320fn is_induction_guard(induction_variables: &HashSet<Name>, guard: &Expression) -> bool {
321    guard
322        .universe_bi()
323        .iter()
324        .all(|x| induction_variables.contains(x))
325}