conjure_cp_core/ast/
comprehension.rs

1#![allow(clippy::arc_with_non_send_sync)]
2
3mod expand_ac;
4
5use std::{
6    cell::RefCell,
7    collections::{BTreeSet, HashMap},
8    fmt::Display,
9    rc::Rc,
10    sync::{
11        Arc, Mutex, RwLock,
12        atomic::{AtomicBool, Ordering},
13    },
14};
15
16use expand_ac::add_return_expression_to_generator_model;
17use itertools::Itertools as _;
18use serde::{Deserialize, Serialize};
19use tracing::warn;
20use uniplate::{Biplate, Uniplate};
21
22use crate::{
23    ast::{
24        Atom, DeclarationKind,
25        serde::{HasId as _, ObjId},
26    },
27    bug,
28    context::Context,
29    into_matrix_expr, matrix_expr,
30    metadata::Metadata,
31    rule_engine::{resolve_rule_sets, rewrite_morph, rewrite_naive},
32    solver::{Solver, SolverError},
33};
34
35use super::{
36    DeclarationPtr, Domain, Expression, Model, Moo, Name, Range, SubModel, SymbolTable,
37    ac_operators::ACOperatorKind,
38};
39
40// TODO: better way to specify this?
41
/// Selects the rewriter used on a comprehension's sub-models while it is being expanded.
///
/// `true` selects the optimised rewriter (`rewrite_morph`); `false`, the initial value, selects
/// the naive rewriter (`rewrite_naive`).
///
/// The flag is read with `Ordering::Relaxed` each time a sub-model is rewritten.
pub static USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS: AtomicBool = AtomicBool::new(false);
46
47// TODO: do not use Names to compare variables, use DeclarationPtr and ids instead
48// see issue #930
49//
// this will simplify *a lot* of the gnarly stuff here, but can only be done once everything else
51// uses DeclarationPtr.
52//
53// ~ nikdewally, 10/06/25
54
/// Selects how [`ComprehensionBuilder::with_return_value`] folds guards that reference
/// non-induction variables into the return expression of a comprehension.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ComprehensionKind {
    /// Guards that reference non-induction variables are not yet supported for this kind.
    Sum,
    /// The guard is made to imply the return expression.
    And,
    /// The guard is conjoined with the return expression.
    Or,
}
/// A comprehension.
///
/// A comprehension is stored as two sub-models (one for the return expression, one for the
/// generators and their guards) together with the names of its induction variables.
#[derive(Clone, PartialEq, Eq, Uniplate, Serialize, Deserialize, Debug)]
#[biplate(to=SubModel)]
#[biplate(to=Expression)]
pub struct Comprehension {
    // Sub-model holding the return expression. When built by `ComprehensionBuilder`, its symbol
    // table declares every induction variable as a given.
    return_expression_submodel: SubModel,
    // Sub-model that is solved to enumerate the induction variables. When built by
    // `ComprehensionBuilder`, its symbol table holds the generator declarations and its
    // constraints are the guards that only reference induction variables.
    generator_submodel: SubModel,
    // Names of the induction variables.
    induction_vars: Vec<Name>,
}
69
70impl Comprehension {
71    pub fn domain_of(&self) -> Option<Domain> {
72        let return_expr_domain = self
73            .return_expression_submodel
74            .clone()
75            .into_single_expression()
76            .domain_of()?;
77
78        // return a list (matrix with index domain int(1..)) of return_expr elements
79        Some(Domain::Matrix(
80            Box::new(return_expr_domain),
81            vec![Domain::Int(vec![Range::UnboundedR(1)])],
82        ))
83    }
84
85    /// Expands the comprehension using Minion, returning the resulting expressions.
86    ///
87    /// This method performs simple pruning of the induction variables: an expression is returned
88    /// for each assignment to the induction variables that satisfy the static guards of the
89    /// comprehension. If the comprehension is inside an associative-commutative operation, use
90    /// [`expand_ac`] instead, as this performs further pruning of "uninteresting" return values.
91    ///
92    /// If successful, this modifies the symbol table given to add aux-variables needed inside the
93    /// expanded expressions.
94    pub fn expand_simple(self, symtab: &mut SymbolTable) -> Result<Vec<Expression>, SolverError> {
95        let minion = Solver::new(crate::solver::adaptors::Minion::new());
96        // FIXME: weave proper context through
97        let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
98
99        // only branch on the induction variables.
100        model.search_order = Some(self.induction_vars.clone());
101        *model.as_submodel_mut() = self.generator_submodel.clone();
102
103        // TODO:  if expand_ac is enabled, add Better_AC_Comprehension_Expansion here.
104
105        // call rewrite here as well as in expand_ac, just to be consistent
106        let extra_rule_sets = &["Base", "Constant", "Bubble"];
107
108        let rule_sets =
109            resolve_rule_sets(crate::solver::SolverFamily::Minion, extra_rule_sets).unwrap();
110
111        let model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
112            rewrite_morph(model, &rule_sets, false)
113        } else {
114            rewrite_naive(&model, &rule_sets, false, false).unwrap()
115        };
116
117        // HACK: also call the rewriter to rewrite inside the comprehension
118        //
119        // The original idea was to let the top level rewriter rewrite the return expression model
120        // and the generator model. The comprehension wouldn't be expanded until the generator
121        // model is in valid minion that can be ran, at which point the return expression model
122        // should also be in valid minion.
123        //
124        // By calling the rewriter inside the rule, we no longer wait for the generator model to be
125        // valid Minion, so we don't get the simplified return model either...
126        //
127        // We need to do this as we want to modify the generator model (add the dummy Z's) then
128        // solve and return in one go.
129        //
130        // Comprehensions need a big rewrite soon, as theres lots of sharp edges such as this in
131        // my original implementation, and I don't think we can fit our new optimisation into it.
132        // If we wanted to avoid calling the rewriter, we would need to run the first half the rule
133        // up to adding the return expr to the generator model, yield, then come back later to
134        // actually solve it?
135
136        let return_expression_submodel = self.return_expression_submodel.clone();
137        let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
138        *return_expression_model.as_submodel_mut() = return_expression_submodel;
139
140        let return_expression_model =
141            if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
142                rewrite_morph(return_expression_model, &rule_sets, false)
143            } else {
144                rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
145            };
146
147        let minion = minion.load_model(model.clone())?;
148
149        let values = Arc::new(Mutex::new(Vec::new()));
150        let values_ptr = Arc::clone(&values);
151
152        tracing::debug!(model=%model,comprehension=%self,"Minion solving comprehension (simple mode)");
153        minion.solve(Box::new(move |sols| {
154            // TODO: deal with represented names if induction variables are abslits.
155            let values = &mut *values_ptr.lock().unwrap();
156            values.push(sols);
157            true
158        }))?;
159
160        let values = values.lock().unwrap().clone();
161
162        let mut return_expressions = vec![];
163
164        for value in values {
165            // convert back to an expression
166
167            let return_expression_submodel = return_expression_model.as_submodel().clone();
168            let child_symtab = return_expression_submodel.symbols().clone();
169            let return_expression = return_expression_submodel.into_single_expression();
170
171            // we only want to substitute induction variables.
172            // (definitely not machine names, as they mean something different in this scope!)
173            let value: HashMap<_, _> = value
174                .into_iter()
175                .filter(|(n, _)| self.induction_vars.contains(n))
176                .collect();
177
178            let value_ptr = Arc::new(value);
179            let value_ptr_2 = Arc::clone(&value_ptr);
180
181            // substitute in the values for the induction variables
182            let return_expression = return_expression.transform_bi(&move |x: Atom| {
183                let Atom::Reference(ref ptr) = x else {
184                    return x;
185                };
186
187                // is this referencing an induction var?
188                let Some(lit) = value_ptr_2.get(&ptr.name()) else {
189                    return x;
190                };
191
192                Atom::Literal(lit.clone())
193            });
194
195            // Copy the return expression's symbols into parent scope.
196
197            // For variables in the return expression with machine names, create new declarations
198            // for them in the parent symbol table, so that the machine names used are unique.
199            //
200            // Store the declaration translations in `machine_name_translations`.
201            // These are stored as a map of (old declaration id) -> (new declaration ptr), as
202            // declaration pointers do not implement hash.
203            //
204            let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
205
206            // Populate `machine_name_translations`
207            for (name, decl) in child_symtab.into_iter_local() {
208                // do not add givens for induction vars to the parent symbol table.
209                if value_ptr.get(&name).is_some()
210                    && matches!(&decl.kind() as &DeclarationKind, DeclarationKind::Given(_))
211                {
212                    continue;
213                }
214
215                let Name::Machine(_) = &name else {
216                    bug!(
217                        "the symbol table of the return expression of a comprehension should only contain machine names"
218                    );
219                };
220
221                let id = decl.id();
222                let new_decl = symtab.gensym(&decl.domain().unwrap());
223
224                machine_name_translations.insert(id, new_decl);
225            }
226
227            // Update references to use the new delcarations.
228            #[allow(clippy::arc_with_non_send_sync)]
229            let return_expression = return_expression.transform_bi(&move |atom: Atom| {
230                if let Atom::Reference(ref decl) = atom
231                    && let id = decl.id()
232                    && let Some(new_decl) = machine_name_translations.get(&id)
233                {
234                    Atom::Reference(new_decl.clone())
235                } else {
236                    atom
237                }
238            });
239
240            return_expressions.push(return_expression);
241        }
242
243        Ok(return_expressions)
244    }
245
246    /// Expands the comprehension using Minion, returning the resulting expressions.
247    ///
248    /// This method is only suitable for comprehensions inside an AC operator. The AC operator that
249    /// contains this comprehension should be passed into the `ac_operator` argument.
250    ///
251    /// This method performs additional pruning of "uninteresting" values, only possible when the
252    /// comprehension is inside an AC operator.
253    ///
254    /// If successful, this modifies the symbol table given to add aux-variables needed inside the
255    /// expanded expressions.
256    pub fn expand_ac(
257        self,
258        symtab: &mut SymbolTable,
259        ac_operator: ACOperatorKind,
260    ) -> Result<Vec<Expression>, SolverError> {
261        // ADD RETURN EXPRESSION TO GENERATOR MODEL AS CONSTRAINT
262        // ======================================================
263
264        // References to induction variables in the return expression point to entries in the
265        // return_expression symbol table.
266        //
267        // Change these to point to the corresponding entry in the generator symbol table instead.
268        //
269        // In the generator symbol-table, induction variables are decision variables (as we are
270        // solving for them), but in the return expression symbol table they are givens.
271        let induction_vars_2 = self.induction_vars.clone();
272        let generator_symtab_ptr = Rc::clone(self.generator_submodel.symbols_ptr_unchecked());
273        let return_expression =
274            self.clone()
275                .return_expression()
276                .transform_bi(&move |decl: DeclarationPtr| {
277                    // if this variable is an induction var...
278                    if induction_vars_2.contains(&decl.name()) {
279                        // ... use the generator symbol tables version of it
280
281                        (*generator_symtab_ptr)
282                            .borrow()
283                            .lookup_local(&decl.name())
284                            .unwrap()
285                    } else {
286                        decl
287                    }
288                });
289
290        // Replace all boolean expressions referencing non-induction variables in the return
291        // expression with dummy variables. This allows us to add it as a constraint to the
292        // generator model.
293        let generator_submodel = add_return_expression_to_generator_model(
294            self.generator_submodel.clone(),
295            return_expression,
296            &ac_operator,
297        );
298
299        // REWRITE GENERATOR MODEL AND PASS TO MINION
300        // ==========================================
301
302        let mut generator_model = Model::new(Arc::new(RwLock::new(Context::default())));
303
304        *generator_model.as_submodel_mut() = generator_submodel;
305
306        // only branch on the induction variables.
307        generator_model.search_order = Some(self.induction_vars.clone());
308
309        let extra_rule_sets = &[
310            "Base",
311            "Constant",
312            "Bubble",
313            "Better_AC_Comprehension_Expansion",
314        ];
315
316        let rule_sets =
317            resolve_rule_sets(crate::solver::SolverFamily::Minion, extra_rule_sets).unwrap();
318
319        let generator_model = if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
320            rewrite_morph(generator_model, &rule_sets, false)
321        } else {
322            rewrite_naive(&generator_model, &rule_sets, false, false).unwrap()
323        };
324
325        let minion = Solver::new(crate::solver::adaptors::Minion::new());
326        let minion = minion.load_model(generator_model.clone());
327
328        let minion = match minion {
329            Err(e) => {
330                warn!(why=%e,model=%generator_model,"Loading generator model failed, failing expand_ac rule");
331                return Err(e);
332            }
333            Ok(minion) => minion,
334        };
335
336        // REWRITE RETURN EXPRESSION
337        // =========================
338
339        let return_expression_submodel = self.return_expression_submodel.clone();
340        let mut return_expression_model = Model::new(Arc::new(RwLock::new(Context::default())));
341        *return_expression_model.as_submodel_mut() = return_expression_submodel;
342
343        let return_expression_model =
344            if USE_OPTIMISED_REWRITER_FOR_COMPREHENSIONS.load(Ordering::Relaxed) {
345                rewrite_morph(return_expression_model, &rule_sets, false)
346            } else {
347                rewrite_naive(&return_expression_model, &rule_sets, false, false).unwrap()
348            };
349
350        let values = Arc::new(Mutex::new(Vec::new()));
351        let values_ptr = Arc::clone(&values);
352
353        // SOLVE FOR THE INDUCTION VARIABLES, AND SUBSTITUTE INTO THE REWRITTEN RETURN EXPRESSION
354        // ======================================================================================
355
356        tracing::debug!(model=%generator_model,comprehension=%self,"Minion solving comprehnesion (ac mode)");
357
358        minion.solve(Box::new(move |sols| {
359            // TODO: deal with represented names if induction variables are abslits.
360            let values = &mut *values_ptr.lock().unwrap();
361            values.push(sols);
362            true
363        }))?;
364
365        let values = values.lock().unwrap().clone();
366
367        let mut return_expressions = vec![];
368
369        for value in values {
370            // convert back to an expression
371
372            let return_expression_submodel = return_expression_model.as_submodel().clone();
373            let child_symtab = return_expression_submodel.symbols().clone();
374            let return_expression = return_expression_submodel.into_single_expression();
375
376            // we only want to substitute induction variables.
377            // (definitely not machine names, as they mean something different in this scope!)
378            let value: HashMap<_, _> = value
379                .into_iter()
380                .filter(|(n, _)| self.induction_vars.contains(n))
381                .collect();
382
383            let value_ptr = Arc::new(value);
384            let value_ptr_2 = Arc::clone(&value_ptr);
385
386            // substitute in the values for the induction variables
387            let return_expression = return_expression.transform_bi(&move |x: Atom| {
388                let Atom::Reference(ref ptr) = x else {
389                    return x;
390                };
391
392                // is this referencing an induction var?
393                let Some(lit) = value_ptr_2.get(&ptr.name()) else {
394                    return x;
395                };
396
397                Atom::Literal(lit.clone())
398            });
399
400            // Copy the return expression's symbols into parent scope.
401
402            // For variables in the return expression with machine names, create new declarations
403            // for them in the parent symbol table, so that the machine names used are unique.
404            //
405            // Store the declaration translations in `machine_name_translations`.
406            // These are stored as a map of (old declaration id) -> (new declaration ptr), as
407            // declaration pointers do not implement hash.
408            //
409            let mut machine_name_translations: HashMap<ObjId, DeclarationPtr> = HashMap::new();
410
411            // Populate `machine_name_translations`
412            for (name, decl) in child_symtab.into_iter_local() {
413                // do not add givens for induction vars to the parent symbol table.
414                if value_ptr.get(&name).is_some()
415                    && matches!(&decl.kind() as &DeclarationKind, DeclarationKind::Given(_))
416                {
417                    continue;
418                }
419
420                let Name::Machine(_) = &name else {
421                    bug!(
422                        "the symbol table of the return expression of a comprehension should only contain machine names"
423                    );
424                };
425
426                let id = decl.id();
427                let new_decl = symtab.gensym(&decl.domain().unwrap());
428
429                machine_name_translations.insert(id, new_decl);
430            }
431
432            // Update references to use the new delcarations.
433            #[allow(clippy::arc_with_non_send_sync)]
434            let return_expression = return_expression.transform_bi(&move |atom: Atom| {
435                if let Atom::Reference(ref decl) = atom
436                    && let id = decl.id()
437                    && let Some(new_decl) = machine_name_translations.get(&id)
438                {
439                    Atom::Reference(new_decl.clone())
440                } else {
441                    atom
442                }
443            });
444
445            return_expressions.push(return_expression);
446        }
447
448        Ok(return_expressions)
449    }
450
451    pub fn return_expression(self) -> Expression {
452        self.return_expression_submodel.into_single_expression()
453    }
454
455    pub fn replace_return_expression(&mut self, new_expr: Expression) {
456        let new_expr = match new_expr {
457            Expression::And(_, exprs) if (*exprs).clone().unwrap_list().is_some() => {
458                Expression::Root(Metadata::new(), (*exprs).clone().unwrap_list().unwrap())
459            }
460            expr => Expression::Root(Metadata::new(), vec![expr]),
461        };
462
463        *self.return_expression_submodel.root_mut_unchecked() = new_expr;
464    }
465
466    /// Adds a guard to the comprehension. Returns false if the guard does not only reference induction variables.
467    pub fn add_induction_guard(&mut self, guard: Expression) -> bool {
468        if self.is_induction_guard(&guard) {
469            self.generator_submodel.add_constraint(guard);
470            true
471        } else {
472            false
473        }
474    }
475
476    /// True iff expr only references induction variables.
477    pub fn is_induction_guard(&self, expr: &Expression) -> bool {
478        is_induction_guard(&(self.induction_vars.clone().into_iter().collect()), expr)
479    }
480}
481
482impl Display for Comprehension {
483    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
484        let generators: String = self
485            .generator_submodel
486            .symbols()
487            .clone()
488            .into_iter_local()
489            .map(|(name, decl): (Name, DeclarationPtr)| {
490                let domain: Domain = decl.domain().unwrap();
491                (name, domain)
492            })
493            .map(|(name, domain): (Name, Domain)| format!("{name}: {domain}"))
494            .join(",");
495
496        let guards = self
497            .generator_submodel
498            .constraints()
499            .iter()
500            .map(|x| format!("{x}"))
501            .join(",");
502
503        let generators_and_guards = itertools::join([generators, guards], ",");
504
505        let expression = &self.return_expression_submodel;
506        write!(f, "[{expression} | {generators_and_guards}]")
507    }
508}
509
/// A builder for a comprehension.
///
/// Generators and guards are accumulated with [`ComprehensionBuilder::generator`] and
/// [`ComprehensionBuilder::guard`]; [`ComprehensionBuilder::with_return_value`] then produces
/// the [`Comprehension`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ComprehensionBuilder {
    // guards added so far; they are split into induction guards and other guards when the
    // comprehension is created.
    guards: Vec<Expression>,
    // symbol table containing all the generators
    // for now, this is just used during parsing - a new symbol table is created using this when we initialise the comprehension
    // this is not ideal, but i am chucking all this code very soon anyways...
    generator_symboltable: Rc<RefCell<SymbolTable>>,
    // symbol table for the return expression; each generator is inserted into it as a given.
    return_expr_symboltable: Rc<RefCell<SymbolTable>>,
    // names of the induction variables declared by the generators added so far.
    induction_variables: BTreeSet<Name>,
}
521
522impl ComprehensionBuilder {
523    pub fn new(symbol_table_ptr: Rc<RefCell<SymbolTable>>) -> Self {
524        ComprehensionBuilder {
525            guards: vec![],
526            generator_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
527                symbol_table_ptr.clone(),
528            ))),
529            return_expr_symboltable: Rc::new(RefCell::new(SymbolTable::with_parent(
530                symbol_table_ptr,
531            ))),
532            induction_variables: BTreeSet::new(),
533        }
534    }
535
536    /// The symbol table for the comprehension generators
537    pub fn generator_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
538        Rc::clone(&self.generator_symboltable)
539    }
540
541    /// The symbol table for the comprehension return expression
542    pub fn return_expr_symboltable(&mut self) -> Rc<RefCell<SymbolTable>> {
543        Rc::clone(&self.return_expr_symboltable)
544    }
545
546    pub fn guard(mut self, guard: Expression) -> Self {
547        self.guards.push(guard);
548        self
549    }
550
551    pub fn generator(mut self, declaration: DeclarationPtr) -> Self {
552        let name = declaration.name().clone();
553        let domain = declaration.domain().unwrap();
554        assert!(!self.induction_variables.contains(&name));
555
556        self.induction_variables.insert(name.clone());
557
558        // insert into generator symbol table as a variable
559        (*self.generator_symboltable)
560            .borrow_mut()
561            .insert(declaration);
562
563        // insert into return expression symbol table as a given
564        (*self.return_expr_symboltable)
565            .borrow_mut()
566            .insert(DeclarationPtr::new_given(name, domain));
567
568        self
569    }
570
571    /// Creates a comprehension with the given return expression.
572    ///
573    /// If a comprehension kind is not given, comprehension guards containing decision variables
574    /// are invalid, and will cause a panic.
575    pub fn with_return_value(
576        self,
577        mut expression: Expression,
578        comprehension_kind: Option<ComprehensionKind>,
579    ) -> Comprehension {
580        let parent_symboltable = self
581            .generator_symboltable
582            .as_ref()
583            .borrow_mut()
584            .parent_mut_unchecked()
585            .clone()
586            .unwrap();
587        let mut generator_submodel = SubModel::new(parent_symboltable.clone());
588        let mut return_expression_submodel = SubModel::new(parent_symboltable);
589
590        *generator_submodel.symbols_ptr_unchecked_mut() = self.generator_symboltable;
591        *return_expression_submodel.symbols_ptr_unchecked_mut() = self.return_expr_symboltable;
592
593        // TODO:also allow guards that reference lettings and givens.
594
595        let induction_variables = self.induction_variables;
596
597        // only guards referencing induction variables can go inside the comprehension
598        let (mut induction_guards, mut other_guards): (Vec<_>, Vec<_>) = self
599            .guards
600            .into_iter()
601            .partition(|x| is_induction_guard(&induction_variables, x));
602
603        let induction_variables_2 = induction_variables.clone();
604        let generator_symboltable_ptr = generator_submodel.symbols_ptr_unchecked().clone();
605
606        // fix induction guard pointers so that they all point to variables in the generator model
607        induction_guards =
608            Biplate::<DeclarationPtr>::transform_bi(&induction_guards, &move |decl| {
609                if induction_variables_2.contains(&decl.name()) {
610                    (*generator_symboltable_ptr)
611                        .borrow()
612                        .lookup_local(&decl.name())
613                        .unwrap()
614                } else {
615                    decl
616                }
617            })
618            .into_iter()
619            .collect_vec();
620
621        let induction_variables_2 = induction_variables.clone();
622        let return_expr_symboltable_ptr =
623            return_expression_submodel.symbols_ptr_unchecked().clone();
624
625        // fix other guard pointers so that they all point to variables in the return expr model
626        other_guards = Biplate::<DeclarationPtr>::transform_bi(&other_guards, &move |decl| {
627            if induction_variables_2.contains(&decl.name()) {
628                (*return_expr_symboltable_ptr)
629                    .borrow()
630                    .lookup_local(&decl.name())
631                    .unwrap()
632            } else {
633                decl
634            }
635        })
636        .into_iter()
637        .collect_vec();
638
639        // handle guards that reference non-induction variables
640        if !other_guards.is_empty() {
641            let comprehension_kind = comprehension_kind.expect(
642                "if any guards reference decision variables, a comprehension kind should be given",
643            );
644
645            let guard_expr = match other_guards.as_slice() {
646                [x] => x.clone(),
647                xs => Expression::And(Metadata::new(), Moo::new(into_matrix_expr!(xs.to_vec()))),
648            };
649
650            expression = match comprehension_kind {
651                ComprehensionKind::And => {
652                    Expression::Imply(Metadata::new(), Moo::new(guard_expr), Moo::new(expression))
653                }
654                ComprehensionKind::Or => Expression::And(
655                    Metadata::new(),
656                    Moo::new(Expression::And(
657                        Metadata::new(),
658                        Moo::new(matrix_expr![guard_expr, expression]),
659                    )),
660                ),
661
662                ComprehensionKind::Sum => {
663                    panic!("guards that reference decision variables not yet implemented for sum");
664                }
665            }
666        }
667
668        generator_submodel.add_constraints(induction_guards);
669
670        return_expression_submodel.add_constraint(expression);
671
672        Comprehension {
673            return_expression_submodel,
674            generator_submodel,
675            induction_vars: induction_variables.into_iter().collect_vec(),
676        }
677    }
678}
679
680/// True iff the guard only references induction variables.
681fn is_induction_guard(induction_variables: &BTreeSet<Name>, guard: &Expression) -> bool {
682    guard
683        .universe_bi()
684        .iter()
685        .all(|x| induction_variables.contains(x))
686}