1use std::{
2 cell::RefCell,
3 collections::{HashMap, HashSet},
4 fmt::Display,
5 rc::Rc,
6 sync::{Arc, Mutex, RwLock},
7};
8
9use itertools::Itertools as _;
10use serde::{Deserialize, Serialize};
11use uniplate::{derive::Uniplate, Biplate};
12
13use crate::{
14 ast::{Atom, DeclarationKind},
15 bug,
16 context::Context,
17 into_matrix_expr, matrix_expr,
18 metadata::Metadata,
19 solver::{Solver, SolverError},
20};
21
22use super::{Declaration, Domain, Expression, Model, Name, SubModel, SymbolTable};
23
/// The operator a comprehension is an argument to.
///
/// This determines how guards that reference decision variables are folded into the
/// return expression when the comprehension is built.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ComprehensionKind {
    /// `sum([ ... | ... ])`
    Sum,
    /// `and([ ... | ... ])`
    And,
    /// `or([ ... | ... ])`
    Or,
}
/// A comprehension, printed as `[return_expression | generators, guards]`.
///
/// It is stored as two submodels, both created with the same parent symbol table by
/// `ComprehensionBuilder::with_return_value`:
///
/// * `generator_submodel` declares the induction variables as decision variables and holds,
///   as constraints, the guards that only reference induction variables;
/// * `return_expression_submodel` declares the induction variables as givens and holds the
///   return expression.
#[derive(Clone, PartialEq, Eq, Uniplate, Serialize, Deserialize, Debug)]
#[uniplate(walk_into=[SubModel])]
#[biplate(to=SubModel)]
#[biplate(to=Expression,walk_into=[SubModel])]
pub struct Comprehension {
    /// Submodel holding the return expression; induction variables are declared in its
    /// symbol table as givens.
    return_expression_submodel: SubModel,
    /// Submodel declaring the induction variables and holding the induction guards; solving
    /// it enumerates the induction-variable assignments.
    generator_submodel: SubModel,
    /// The induction variables; used as the search order when solving `generator_submodel`.
    induction_vars: Vec<Name>,
}
39
40impl Comprehension {
41 pub fn domain_of(&self, syms: &SymbolTable) -> Option<Domain> {
42 self.return_expression_submodel
43 .clone()
44 .into_single_expression()
45 .domain_of(syms)
46 }
47
48 pub fn solve_with_minion(
53 self,
54 symtab: &mut SymbolTable,
55 ) -> Result<Vec<Expression>, SolverError> {
56 let minion = Solver::new(crate::solver::adaptors::Minion::new());
57 let mut model = Model::new(Arc::new(RwLock::new(Context::default())));
59
60 model.search_order = Some(self.induction_vars.clone());
62
63 *model.as_submodel_mut() = self.generator_submodel.clone();
64
65 let minion = minion.load_model(model.clone())?;
66
67 let values = Arc::new(Mutex::new(Vec::new()));
68 let values_ptr = Arc::clone(&values);
69
70 tracing::debug!(model=%model.clone(),comprehension=%self.clone(),"Minion solving comprehension");
71 minion.solve(Box::new(move |sols| {
72 let values = &mut *values_ptr.lock().unwrap();
74 values.push(sols);
75 true
76 }))?;
77
78 let values = values.lock().unwrap().clone();
79
80 let mut return_expressions = vec![];
81
82 for value in values {
83 let return_expression_submodel = self.return_expression_submodel.clone();
86 let child_symtab = return_expression_submodel.symbols().clone();
87 let return_expression = return_expression_submodel.into_single_expression();
88
89 let value: HashMap<_, _> = value
92 .into_iter()
93 .filter(|(n, _)| self.induction_vars.contains(n))
94 .collect();
95
96 let value_ptr = Arc::new(value);
97 let value_ptr_2 = Arc::clone(&value_ptr);
98
99 let return_expression = return_expression.transform_bi(Arc::new(move |x: Atom| {
101 let Atom::Reference(ref name) = x else {
102 return x;
103 };
104
105 let Some(lit) = value_ptr_2.get(name) else {
107 return x;
108 };
109
110 Atom::Literal(lit.clone())
111 }));
112
113 let mut machine_name_translations: HashMap<Name, Name> = HashMap::new();
118
119 for (name, decl) in child_symtab.into_iter_local() {
121 if value_ptr.get(&name).is_some()
123 && matches!(decl.kind(), DeclarationKind::Given(_))
124 {
125 continue;
126 }
127
128 let Name::Machine(_) = &name else {
129 bug!("the symbol table of the return expression of a comprehension should only contain machine names");
130 };
131
132 let new_machine_name = symtab.gensym();
133
134 let new_decl = (*decl).clone().with_new_name(new_machine_name.clone());
135 symtab.insert(Rc::new(new_decl)).unwrap();
136
137 machine_name_translations.insert(name, new_machine_name);
138 }
139
140 let return_expression =
142 return_expression.transform_bi(Arc::new(
143 move |name| match machine_name_translations.get(&name) {
144 Some(new_name) => new_name.clone(),
145 None => name,
146 },
147 ));
148
149 return_expressions.push(return_expression);
150 }
151
152 Ok(return_expressions)
153 }
154
155 pub fn return_expression(self) -> Expression {
156 self.return_expression_submodel.into_single_expression()
157 }
158
159 pub fn replace_return_expression(&mut self, new_expr: Expression) {
160 let new_expr = match new_expr {
161 Expression::And(_, exprs) if exprs.clone().unwrap_list().is_some() => {
162 Expression::Root(Metadata::new(), exprs.unwrap_list().unwrap())
163 }
164 expr => Expression::Root(Metadata::new(), vec![expr]),
165 };
166
167 *self.return_expression_submodel.root_mut_unchecked() = new_expr;
168 }
169
170 pub fn add_induction_guard(&mut self, guard: Expression) -> bool {
172 if self.is_induction_guard(&guard) {
173 self.generator_submodel.add_constraint(guard);
174 true
175 } else {
176 false
177 }
178 }
179
180 pub fn is_induction_guard(&self, expr: &Expression) -> bool {
182 is_induction_guard(&(self.induction_vars.clone().into_iter().collect()), expr)
183 }
184}
185
186impl Display for Comprehension {
187 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
188 let generators: String = self
189 .generator_submodel
190 .symbols()
191 .clone()
192 .into_iter_local()
193 .map(|(name, decl)| (name, decl.domain().unwrap().clone()))
194 .map(|(name, domain)| format!("{name}: {domain}"))
195 .join(",");
196
197 let guards = self
198 .generator_submodel
199 .constraints()
200 .iter()
201 .map(|x| format!("{x}"))
202 .join(",");
203
204 let generators_and_guards = itertools::join([generators, guards], ",");
205
206 let expression = &self.return_expression_submodel;
207 write!(f, "[{expression} | {generators_and_guards}]")
208 }
209}
210
/// Incrementally builds a [`Comprehension`] from generators and guards.
///
/// Finish with [`ComprehensionBuilder::with_return_value`].
#[derive(Clone, Debug, PartialEq, Eq, Default)]
pub struct ComprehensionBuilder {
    /// Guards (filter conditions), in the order they were added.
    guards: Vec<Expression>,
    /// Generators `name: domain`, in the order they were added.
    generators: Vec<(Name, Domain)>,
    /// Names of all generators; used to reject duplicate generators and to classify guards.
    induction_variables: HashSet<Name>,
}
218
219impl ComprehensionBuilder {
220 pub fn new() -> Self {
221 Default::default()
222 }
223 pub fn guard(mut self, guard: Expression) -> Self {
224 self.guards.push(guard);
225 self
226 }
227
228 pub fn generator(mut self, name: Name, domain: Domain) -> Self {
229 assert!(!self.induction_variables.contains(&name));
230 self.induction_variables.insert(name.clone());
231 self.generators.push((name, domain));
232 self
233 }
234
235 pub fn with_return_value(
240 self,
241 mut expression: Expression,
242 parent: Rc<RefCell<SymbolTable>>,
243 comprehension_kind: Option<ComprehensionKind>,
244 ) -> Comprehension {
245 let mut generator_submodel = SubModel::new(parent.clone());
246
247 let induction_variables = self.induction_variables;
250
251 let (induction_guards, other_guards): (Vec<_>, Vec<_>) = self
253 .guards
254 .into_iter()
255 .partition(|x| is_induction_guard(&induction_variables, x));
256
257 if !other_guards.is_empty() {
259 let comprehension_kind = comprehension_kind.expect(
260 "if any guards reference decision variables, a comprehension kind should be given",
261 );
262
263 let guard_expr = match other_guards.as_slice() {
264 [x] => x.clone(),
265 xs => Expression::And(Metadata::new(), Box::new(into_matrix_expr!(xs.to_vec()))),
266 };
267
268 expression = match comprehension_kind {
269 ComprehensionKind::And => {
270 Expression::Imply(Metadata::new(), Box::new(guard_expr), Box::new(expression))
271 }
272 ComprehensionKind::Or => Expression::And(
273 Metadata::new(),
274 Box::new(Expression::And(
275 Metadata::new(),
276 Box::new(matrix_expr![guard_expr, expression]),
277 )),
278 ),
279
280 ComprehensionKind::Sum => {
281 panic!("guards that reference decision variables not yet implemented for sum");
282 }
283 }
284 }
285
286 generator_submodel.add_constraints(induction_guards);
287 for (name, domain) in self.generators.clone() {
288 generator_submodel
289 .symbols_mut()
290 .insert(Rc::new(Declaration::new_var(name, domain)));
291 }
292
293 let mut return_expression_submodel = SubModel::new(parent);
302 for (name, domain) in self.generators {
303 return_expression_submodel
304 .symbols_mut()
305 .insert(Rc::new(Declaration::new_given(name, domain)))
306 .unwrap();
307 }
308
309 return_expression_submodel.add_constraint(expression);
310
311 Comprehension {
312 return_expression_submodel,
313 generator_submodel,
314 induction_vars: induction_variables.into_iter().collect_vec(),
315 }
316 }
317}
318
319fn is_induction_guard(induction_variables: &HashSet<Name>, guard: &Expression) -> bool {
321 guard
322 .universe_bi()
323 .iter()
324 .all(|x| induction_variables.contains(x))
325}