瀏覽代碼

More optimizations

1. Remove useless jumps
Cesar Rodas 9 月之前
父節點
當前提交
d830f48c3b
共有 2 個文件被更改，包括 92 次插入和 20 次删除
  1. 23 13
      utxo/src/filter_expr/compiler.rs
  2. 69 7
      utxo/src/filter_expr/filter.rs

+ 23 - 13
utxo/src/filter_expr/compiler.rs

@@ -130,22 +130,26 @@ impl<'a> Compiler<'a> {
         })
     }
 
-    pub fn resolve_label_to_addr(opcodes: Vec<OpCode>) -> Result<Vec<OpCode>, Error> {
-        let mut pos = 0;
-        let used_labels = opcodes
+    pub fn labels_to_addr(opcodes: &[OpCode]) -> HashMap<Addr, Addr> {
+        let mut pos_without_labels = 0;
+        opcodes
             .iter()
-            .map(|x| match x {
-                OpCode::LABEL(label) => (x, pos),
+            .map(|opcode| match opcode {
+                OpCode::LABEL(label) => (opcode, pos_without_labels),
                 _ => {
-                    pos = pos + 1usize;
-                    (x, pos - 1)
+                    pos_without_labels = pos_without_labels + 1usize;
+                    (opcode, pos_without_labels - 1)
                 }
             })
             .filter_map(|(opcode, pos)| match opcode {
                 OpCode::LABEL(id) => Some((*id, pos.into())),
                 _ => None,
             })
-            .collect::<HashMap<_, _>>();
+            .collect::<HashMap<_, _>>()
+    }
+
+    pub fn resolve_label_to_addr(opcodes: Vec<OpCode>) -> Result<Vec<OpCode>, Error> {
+        let labels_to_addr = Self::labels_to_addr(&opcodes);
 
         opcodes
             .into_iter()
@@ -153,16 +157,22 @@ impl<'a> Compiler<'a> {
             .map(|opcode| {
                 // Rewrite JMP to not use labels but instead addresses
                 Ok(match opcode {
-                    OpCode::JMP(label) => {
-                        OpCode::JMP(*used_labels.get(&label).ok_or(Error::UnknownLabel(*label))?)
-                    }
+                    OpCode::JMP(label) => OpCode::JMP(
+                        *labels_to_addr
+                            .get(&label)
+                            .ok_or(Error::UnknownLabel(*label))?,
+                    ),
                     OpCode::JEQ(register, label) => OpCode::JEQ(
                         register,
-                        *used_labels.get(&label).ok_or(Error::UnknownLabel(*label))?,
+                        *labels_to_addr
+                            .get(&label)
+                            .ok_or(Error::UnknownLabel(*label))?,
                     ),
                     OpCode::JNE(register, label) => OpCode::JNE(
                         register,
-                        *used_labels.get(&label).ok_or(Error::UnknownLabel(*label))?,
+                        *labels_to_addr
+                            .get(&label)
+                            .ok_or(Error::UnknownLabel(*label))?,
                     ),
                     opcode => opcode,
                 })

+ 69 - 7
utxo/src/filter_expr/filter.rs

@@ -58,9 +58,9 @@ impl Filter {
         })
     }
 
-    /// Executes operations that can be done at compile time, and sets the initial_register to the
-    /// result. This is useful for optimizing the program, by executing at `compile` time as
-    /// much as possible.
+    /// Executes operations that can be performed at compile time, and sets the initial_register to
+    /// the result. This is useful for optimizing the program, by executing at `compile` time as
+    /// much as possible, when the terms of the opcodes are not dependent on the runtime state.
     fn calculate_static_values(&mut self) -> bool {
         let mut register = HashMap::new();
         let mut has_changed = false;
@@ -70,6 +70,30 @@ impl Filter {
                 OpCode::LOAD(dst, value) => {
                     register.insert(*dst, value.clone());
                 }
+                OpCode::JEQ(reg, addr) => {
+                    if let Some(Value::Bool(true)) = register.get(reg) {
+                        *opcode = OpCode::JMP(*addr);
+                    }
+                }
+                OpCode::JNE(reg, addr) => {
+                    if let Some(Value::Bool(false)) = register.get(reg) {
+                        *opcode = OpCode::JMP(*addr);
+                    }
+                }
+                OpCode::EQ(dst, reg1, reg2) => {
+                    let value1 = if let Some(value) = register.get(reg1) {
+                        value.clone()
+                    } else {
+                        return;
+                    };
+                    let value2 = if let Some(value) = register.get(reg2) {
+                        value.clone()
+                    } else {
+                        return;
+                    };
+                    let result = value1 == value2;
+                    *opcode = OpCode::LOAD(*dst, result.into());
+                }
                 OpCode::MUL(dst, reg1, reg2)
                 | OpCode::SUB(dst, reg1, reg2)
                 | OpCode::DIV(dst, reg1, reg2)
@@ -169,6 +193,42 @@ impl Filter {
         has_changed
     }
 
+    /// Remove useless jumps
+    ///
+    /// Useless jumps are any kind of jump that is just going to the next line.
+    fn remove_useless_jumps(mut self) -> (Self, bool) {
+        // `pos` is the position of the opcode in the new list of opcodes, where the LABELS are
+        // ignored and not counted
+        let mut pos: usize = 0;
+        let label_to_addr = Compiler::labels_to_addr(&self.opcodes);
+        let old_total_opcodes = self.opcodes.len();
+        let new_opcodes = self
+            .opcodes
+            .into_iter()
+            .map(|opcode| match &opcode {
+                OpCode::LABEL(_) => (pos, opcode),
+                _ => {
+                    pos += 1;
+                    (pos - 1, opcode)
+                }
+            })
+            .filter_map(|(pos, opcode)| match &opcode {
+                OpCode::JEQ(_, addr) | OpCode::JMP(addr) | OpCode::JNE(_, addr) => {
+                    if label_to_addr.get(addr).map(|x| *x) == pos.checked_add(1).map(|r| r.into()) {
+                        None
+                    } else {
+                        Some(opcode)
+                    }
+                }
+                _ => Some(opcode),
+            })
+            .collect::<Vec<_>>();
+        self.opcodes = new_opcodes;
+        let new_total_opcodes = self.opcodes.len();
+
+        (self, old_total_opcodes != new_total_opcodes)
+    }
+
     /// Remove loaded values that are not read by any opcode. This is useful for reducing the size
     /// of the opcodes to be executed
     fn remove_unused_values(mut self) -> (Self, bool) {
@@ -216,7 +276,7 @@ impl Filter {
                 acc
             });
 
-        let total_opcodes = self.opcodes.len();
+        let old_total_opcodes = self.opcodes.len();
 
         // remove unused registers. If the register has not been read by any opcode, then it can be
         // removed.
@@ -239,7 +299,7 @@ impl Filter {
         self.opcodes = new_opcodes;
         let new_total_opcodes = self.opcodes.len();
 
-        (self, total_opcodes != new_total_opcodes)
+        (self, old_total_opcodes != new_total_opcodes)
     }
 
     /// Attempts to optimize the `raw_opcodes` inside the Filter. Returns a tuple with the new
@@ -248,6 +308,7 @@ impl Filter {
         let has_calculated_static_values = self.calculate_static_values();
         let has_changed_register_addresses = self.assign_unique_register_addresses();
         let (mut new_self, has_removed_unused_values) = self.remove_unused_values();
+        let (mut new_self, has_removed_useless_jumps) = new_self.remove_useless_jumps();
 
         new_self.opcodes_to_execute =
             Compiler::resolve_label_to_addr(new_self.opcodes.clone()).unwrap();
@@ -256,7 +317,8 @@ impl Filter {
             new_self,
             has_calculated_static_values
                 || has_removed_unused_values
-                || has_changed_register_addresses,
+                || has_changed_register_addresses
+                || has_removed_useless_jumps,
         )
     }
 
@@ -308,7 +370,7 @@ mod test {
             r#"
             WHERE
                 $foo = 3 + 2 * 4 / 2 * 298210 + $bar
-                AND 5 = 5
+                AND 25 = 5*5
         "#,
         )
         .unwrap();