conjure_oxide/utils/testing.rs

1use std::collections::{BTreeMap, HashMap, HashSet};
2use std::fmt::Debug;
3use std::vec;
4
5use conjure_core::ast::records::RecordValue;
6use conjure_core::bug;
7use itertools::Itertools as _;
8use std::fs::File;
9use std::fs::{read_to_string, OpenOptions};
10use std::hash::Hash;
11use std::io::Write;
12use std::sync::{Arc, RwLock};
13use uniplate::Uniplate;
14
15use conjure_core::ast::{AbstractLiteral, Domain, SerdeModel};
16use conjure_core::context::Context;
17use serde_json::{json, Error as JsonError, Value as JsonValue};
18
19use conjure_core::error::Error;
20
21use crate::ast::Name::User;
22use crate::ast::{Literal, Name};
23use crate::utils::conjure::solutions_to_json;
24use crate::utils::json::sort_json_object;
25use crate::utils::misc::to_set;
26use crate::Model as ConjureModel;
27use crate::SolverFamily;
28
29pub fn assert_eq_any_order<T: Eq + Hash + Debug + Clone>(a: &Vec<Vec<T>>, b: &Vec<Vec<T>>) {
30    assert_eq!(a.len(), b.len());
31
32    let mut a_rows: Vec<HashSet<T>> = Vec::new();
33    for row in a {
34        let hash_row = to_set(row);
35        a_rows.push(hash_row);
36    }
37
38    let mut b_rows: Vec<HashSet<T>> = Vec::new();
39    for row in b {
40        let hash_row = to_set(row);
41        b_rows.push(hash_row);
42    }
43
44    println!("{:?},{:?}", a_rows, b_rows);
45    for row in a_rows {
46        assert!(b_rows.contains(&row));
47    }
48}
49
50pub fn serialise_model(model: &ConjureModel) -> Result<String, JsonError> {
51    // A consistent sorting of the keys of json objects
52    // only required for the generated version
53    // since the expected version will already be sorted
54    let serde_model: SerdeModel = model.clone().into();
55    let generated_json = sort_json_object(&serde_json::to_value(serde_model)?, false);
56
57    // serialise to string
58    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
59
60    Ok(generated_json_str)
61}
62
63pub fn save_model_json(
64    model: &ConjureModel,
65    path: &str,
66    test_name: &str,
67    test_stage: &str,
68) -> Result<(), std::io::Error> {
69    let generated_json_str = serialise_model(model)?;
70    let filename = format!("{path}/{test_name}.generated-{test_stage}.serialised.json");
71    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
72    Ok(())
73}
74
75pub fn save_stats_json(
76    context: Arc<RwLock<Context<'static>>>,
77    path: &str,
78    test_name: &str,
79) -> Result<(), std::io::Error> {
80    #[allow(clippy::unwrap_used)]
81    let stats = context.read().unwrap().clone();
82    let generated_json = sort_json_object(&serde_json::to_value(stats)?, false);
83
84    // serialise to string
85    let generated_json_str = serde_json::to_string_pretty(&generated_json)?;
86
87    File::create(format!("{path}/{test_name}-stats.json"))?
88        .write_all(generated_json_str.as_bytes())?;
89
90    Ok(())
91}
92
93pub fn read_model_json(
94    ctx: &Arc<RwLock<Context<'static>>>,
95    path: &str,
96    test_name: &str,
97    prefix: &str,
98    test_stage: &str,
99) -> Result<ConjureModel, std::io::Error> {
100    let expected_json_str = std::fs::read_to_string(format!(
101        "{path}/{test_name}.{prefix}-{test_stage}.serialised.json"
102    ))?;
103    println!("{path}/{test_name}.{prefix}-{test_stage}.serialised.json");
104    let expected_model: SerdeModel = serde_json::from_str(&expected_json_str)?;
105
106    Ok(expected_model.initialise(ctx.clone()).unwrap())
107}
108
109pub fn minion_solutions_from_json(
110    serialized: &str,
111) -> Result<Vec<HashMap<Name, Literal>>, anyhow::Error> {
112    let json: JsonValue = serde_json::from_str(serialized)?;
113
114    let json_array = json
115        .as_array()
116        .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
117
118    let mut solutions = Vec::new();
119
120    for solution in json_array {
121        let mut sol = HashMap::new();
122        let solution = solution
123            .as_object()
124            .ok_or(Error::Parse("Invalid JSON".to_owned()))?;
125
126        for (var_name, constant) in solution {
127            let constant = match constant {
128                JsonValue::Number(n) => {
129                    let n = n
130                        .as_i64()
131                        .ok_or(Error::Parse("Invalid integer".to_owned()))?;
132                    Literal::Int(n as i32)
133                }
134                JsonValue::Bool(b) => Literal::Bool(*b),
135                _ => return Err(Error::Parse("Invalid constant".to_owned()).into()),
136            };
137
138            sol.insert(User(var_name.into()), constant);
139        }
140
141        solutions.push(sol);
142    }
143
144    Ok(solutions)
145}
146
147/// Writes the minion solutions to a generated JSON file, and returns the JSON structure.
148pub fn save_solutions_json(
149    solutions: &Vec<BTreeMap<Name, Literal>>,
150    path: &str,
151    test_name: &str,
152    solver: SolverFamily,
153) -> Result<JsonValue, std::io::Error> {
154    let json_solutions = solutions_to_json(solutions);
155    let generated_json_str = serde_json::to_string_pretty(&json_solutions)?;
156
157    let solver_name = match solver {
158        SolverFamily::Sat => "sat",
159        SolverFamily::Minion => "minion",
160    };
161
162    let filename = format!("{path}/{test_name}.generated-{solver_name}.solutions.json");
163    File::create(&filename)?.write_all(generated_json_str.as_bytes())?;
164
165    Ok(json_solutions)
166}
167
168pub fn read_solutions_json(
169    path: &str,
170    test_name: &str,
171    prefix: &str,
172    solver: SolverFamily,
173) -> Result<JsonValue, anyhow::Error> {
174    let solver_name = match solver {
175        SolverFamily::Sat => "sat",
176        SolverFamily::Minion => "minion",
177    };
178
179    let expected_json_str = std::fs::read_to_string(format!(
180        "{path}/{test_name}.{prefix}-{solver_name}.solutions.json"
181    ))?;
182
183    let expected_solutions: JsonValue =
184        sort_json_object(&serde_json::from_str(&expected_json_str)?, true);
185
186    Ok(expected_solutions)
187}
188
189/// Reads a rule trace from a file. For the generated prefix, it appends a count message.
190/// Returns the lines of the file as a vector of strings.
191pub fn read_rule_trace(
192    path: &str,
193    test_name: &str,
194    prefix: &str,
195) -> Result<Vec<String>, std::io::Error> {
196    let filename = format!("{path}/{test_name}-{prefix}-rule-trace.json");
197    let mut rules_trace: Vec<String> = read_to_string(&filename)?
198        .lines()
199        .map(String::from)
200        .collect();
201
202    // If prefix is "generated", append the count message
203    if prefix == "generated" {
204        let rule_count = rules_trace.len();
205        let count_message = json!({
206            "message": "Number of rules applied",
207            "count": rule_count
208        });
209        let count_message_string = serde_json::to_string(&count_message)?;
210        rules_trace.push(count_message_string);
211
212        // Overwrite the file with updated content (including the count message)
213        let mut file = OpenOptions::new()
214            .write(true)
215            .truncate(true)
216            .open(&filename)?;
217        writeln!(file, "{}", rules_trace.join("\n"))?;
218    }
219
220    Ok(rules_trace)
221}
222
/// Reads the human-readable rule trace at
/// `{path}/{test_name}-{prefix}-rule-trace-human.txt` and returns it split into lines.
///
/// # Errors
///
/// Returns an I/O error if the file cannot be read (for example, if it does not exist).
pub fn read_human_rule_trace(
    path: &str,
    test_name: &str,
    prefix: &str,
) -> Result<Vec<String>, std::io::Error> {
    let contents = read_to_string(format!("{path}/{test_name}-{prefix}-rule-trace-human.txt"))?;
    Ok(contents.lines().map(str::to_string).collect())
}
237
238#[doc(hidden)]
239pub fn normalize_solutions_for_comparison(
240    input_solutions: &[BTreeMap<Name, Literal>],
241) -> Vec<BTreeMap<Name, Literal>> {
242    let mut normalized = input_solutions.to_vec();
243
244    for solset in &mut normalized {
245        // remove machine names
246        let keys_to_remove: Vec<Name> = solset
247            .keys()
248            .filter(|k| matches!(k, Name::Machine(_)))
249            .cloned()
250            .collect();
251        for k in keys_to_remove {
252            solset.remove(&k);
253        }
254
255        let mut updates = vec![];
256        for (k, v) in solset.clone() {
257            if let Name::User(_) = k {
258                match v {
259                    Literal::Bool(true) => updates.push((k, Literal::Int(1))),
260                    Literal::Bool(false) => updates.push((k, Literal::Int(0))),
261                    Literal::Int(_) => {}
262                    Literal::AbstractLiteral(AbstractLiteral::Matrix(elems, _)) => {
263                        // make all domains the same (this is just in the tester so the types dont
264                        // actually matter)
265
266                        let mut matrix =
267                            AbstractLiteral::Matrix(elems, Box::new(Domain::Int(vec![])));
268                        matrix =
269                            matrix.transform(Arc::new(
270                                move |x: AbstractLiteral<Literal>| match x {
271                                    AbstractLiteral::Matrix(items, _) => {
272                                        let items = items
273                                            .into_iter()
274                                            .map(|x| match x {
275                                                Literal::Bool(false) => Literal::Int(0),
276                                                Literal::Bool(true) => Literal::Int(1),
277                                                x => x,
278                                            })
279                                            .collect_vec();
280
281                                        AbstractLiteral::Matrix(
282                                            items,
283                                            Box::new(Domain::Int(vec![])),
284                                        )
285                                    }
286                                    x => x,
287                                },
288                            ));
289                        updates.push((k, Literal::AbstractLiteral(matrix)));
290                    }
291                    Literal::AbstractLiteral(AbstractLiteral::Tuple(elems)) => {
292                        // just the same as matrix but with tuples instead
293                        // only conversion needed is to convert bools to ints
294                        let mut tuple = AbstractLiteral::Tuple(elems);
295                        tuple =
296                            tuple.transform(Arc::new(move |x: AbstractLiteral<Literal>| match x {
297                                AbstractLiteral::Tuple(items) => {
298                                    let items = items
299                                        .into_iter()
300                                        .map(|x| match x {
301                                            Literal::Bool(false) => Literal::Int(0),
302                                            Literal::Bool(true) => Literal::Int(1),
303                                            x => x,
304                                        })
305                                        .collect_vec();
306
307                                    AbstractLiteral::Tuple(items)
308                                }
309                                x => x,
310                            }));
311                        updates.push((k, Literal::AbstractLiteral(tuple)));
312                    }
313                    Literal::AbstractLiteral(AbstractLiteral::Record(entries)) => {
314                        // just the same as matrix but with tuples instead
315                        // only conversion needed is to convert bools to ints
316                        let mut record = AbstractLiteral::Record(entries);
317                        record =
318                            record.transform(Arc::new(
319                                move |x: AbstractLiteral<Literal>| match x {
320                                    AbstractLiteral::Record(entries) => {
321                                        let entries = entries
322                                            .into_iter()
323                                            .map(|x| {
324                                                let RecordValue { name, value } = x;
325                                                {
326                                                    let value = match value {
327                                                        Literal::Bool(false) => Literal::Int(0),
328                                                        Literal::Bool(true) => Literal::Int(1),
329                                                        x => x,
330                                                    };
331                                                    RecordValue { name, value }
332                                                }
333                                            })
334                                            .collect_vec();
335
336                                        AbstractLiteral::Record(entries)
337                                    }
338                                    x => x,
339                                },
340                            ));
341                        updates.push((k, Literal::AbstractLiteral(record)));
342                    }
343                    e => bug!("unexpected literal type: {e:?}"),
344                }
345            }
346        }
347
348        for (k, v) in updates {
349            solset.insert(k, v);
350        }
351    }
352
353    // Remove duplicates
354    normalized = normalized.into_iter().unique().collect();
355    normalized
356}