From d4c66e294bb9c76d043f41f5db88f001292dcd02 Mon Sep 17 00:00:00 2001 From: Justin Viiret Date: Fri, 31 Mar 2017 14:04:44 +1100 Subject: [PATCH] smallwrite: aho-corasick construction for literals --- src/smallwrite/smallwrite_build.cpp | 410 ++++++++++++++++++++++------ src/smallwrite/smallwrite_build.h | 19 +- 2 files changed, 333 insertions(+), 96 deletions(-) diff --git a/src/smallwrite/smallwrite_build.cpp b/src/smallwrite/smallwrite_build.cpp index f7c9ad8c..a27db736 100644 --- a/src/smallwrite/smallwrite_build.cpp +++ b/src/smallwrite/smallwrite_build.cpp @@ -26,6 +26,11 @@ * POSSIBILITY OF SUCH DAMAGE. */ +/** + * \file + * \brief Small-write engine build code. + */ + #include "smallwrite/smallwrite_build.h" #include "grey.h" @@ -48,6 +53,7 @@ #include "util/alloc.h" #include "util/bytecode_ptr.h" #include "util/charreach.h" +#include "util/compare.h" #include "util/compile_context.h" #include "util/container.h" #include "util/make_unique.h" @@ -60,36 +66,58 @@ #include #include +#include + using namespace std; namespace ue2 { -#define LITERAL_MERGE_CHUNK_SIZE 25 #define DFA_MERGE_MAX_STATES 8000 #define MAX_TRIE_VERTICES 8000 -namespace { // unnamed - struct LitTrieVertexProps { LitTrieVertexProps() = default; - explicit LitTrieVertexProps(char c_in) : c(c_in) {} - char c = 0; + explicit LitTrieVertexProps(u8 c_in) : c(c_in) {} size_t index; // managed by ue2_graph + u8 c = 0; //!< character reached on this vertex + flat_set reports; //!< managed reports fired on this vertex }; struct LitTrieEdgeProps { - LitTrieEdgeProps() = default; size_t index; // managed by ue2_graph }; +/** + * \brief BGL graph used to store a trie of literals (for later AC construction + * into a DFA). + */ struct LitTrie : public ue2_graph { LitTrie() : root(add_vertex(*this)) {} - const vertex_descriptor root; + const vertex_descriptor root; //!< Root vertex for the trie. }; +static +bool is_empty(const LitTrie &trie) { + return num_vertices(trie) <= 1; +} + +static +std::set all_reports(const LitTrie &trie) { + std::set reports; + for (auto v : vertices_range(trie)) { + insert(&reports, trie[v].reports); + } + return reports; +} + +using LitTrieVertex = LitTrie::vertex_descriptor; +using LitTrieEdge = LitTrie::edge_descriptor; + +namespace { // unnamed + // Concrete impl class class SmallWriteBuildImpl : public SmallWriteBuild { public: @@ -110,15 +138,15 @@ public: const CompileContext &cc; unique_ptr rdfa; - vector > cand_literals; LitTrie lit_trie; LitTrie lit_trie_nocase; + size_t num_literals = 0; bool poisoned; }; } // namespace -SmallWriteBuild::~SmallWriteBuild() { } +SmallWriteBuild::~SmallWriteBuild() = default; SmallWriteBuildImpl::SmallWriteBuildImpl(size_t num_patterns, const ReportManager &rm_in, @@ -272,25 +300,27 @@ void SmallWriteBuildImpl::add(const NGHolder &g, const ExpressionInfo &expr) { } static -bool add_to_trie(const ue2_literal &literal, LitTrie &trie) { +bool add_to_trie(const ue2_literal &literal, ReportID report, LitTrie &trie) { auto u = trie.root; - for (auto &c : literal) { + for (const auto &c : literal) { auto next = LitTrie::null_vertex(); for (auto v : adjacent_vertices_range(u, trie)) { - if (trie[v].c == c.c) { + if (trie[v].c == (u8)c.c) { next = v; break; } } - if (next == LitTrie::null_vertex()) { - next = add_vertex(LitTrieVertexProps(c.c), trie); + if (!next) { + next = add_vertex(LitTrieVertexProps((u8)c.c), trie); add_edge(u, next, trie); } u = next; } - DEBUG_PRINTF("added '%s' to trie, now %zu vertices\n", - escapeString(literal).c_str(), num_vertices(trie)); + trie[u].reports.insert(report); + + DEBUG_PRINTF("added '%s' (report %u) to trie, now %zu vertices\n", + escapeString(literal).c_str(), report, num_vertices(trie)); return num_vertices(trie) <= MAX_TRIE_VERTICES; } @@ -298,105 +328,310 @@ void SmallWriteBuildImpl::add(const ue2_literal &literal, ReportID r) { // If the graph is poisoned (i.e. we can't build a SmallWrite version), // we don't even try. if (poisoned) { + DEBUG_PRINTF("poisoned\n"); return; } if (literal.length() > cc.grey.smallWriteLargestBuffer) { + DEBUG_PRINTF("exceeded length limit\n"); return; /* too long */ } - cand_literals.push_back(make_pair(literal, r)); - - if (!add_to_trie(literal, - literal.any_nocase() ? lit_trie_nocase : lit_trie)) { + if (++num_literals > cc.grey.smallWriteMaxLiterals) { + DEBUG_PRINTF("exceeded literal limit\n"); poisoned = true; return; } - if (cand_literals.size() > cc.grey.smallWriteMaxLiterals) { + auto &trie = literal.any_nocase() ? lit_trie_nocase : lit_trie; + if (!add_to_trie(literal, r, trie)) { + DEBUG_PRINTF("trie add failed\n"); poisoned = true; } } -static -void lit_to_graph(NGHolder *h, const ue2_literal &literal, ReportID r) { - NFAVertex u = h->startDs; - for (const auto &c : literal) { - NFAVertex v = add_vertex(*h); - add_edge(u, v, *h); - (*h)[v].char_reach = c; - u = v; +namespace { + +/** + * \brief BFS visitor for Aho-Corasick automaton construction. + * + * This is doing two things: + * + * - Computing the failure edges (also called fall or supply edges) for each + * vertex, giving the longest suffix of the path to that point that is also + * a prefix in the trie reached on the same character. The BFS traversal + * makes it possible to build these from earlier failure paths. + * + * - Computing the output function for each vertex, which is done by + * propagating the reports from failure paths as well. This ensures that + * substrings of the current path also report correctly. + */ +struct ACVisitor : public boost::default_bfs_visitor { + ACVisitor(LitTrie &trie_in, + map &failure_map_in, + vector &ordering_in) + : mutable_trie(trie_in), failure_map(failure_map_in), + ordering(ordering_in) {} + + LitTrieVertex find_failure_target(LitTrieVertex u, LitTrieVertex v, + const LitTrie &trie) { + assert(u == trie.root || contains(failure_map, u)); + assert(!contains(failure_map, v)); + + const auto &c = trie[v].c; + + while (u != trie.root) { + auto f = failure_map.at(u); + for (auto w : adjacent_vertices_range(f, trie)) { + if (trie[w].c == c) { + return w; + } + } + u = f; + } + + DEBUG_PRINTF("no failure edge\n"); + return LitTrie::null_vertex(); } - (*h)[u].reports.insert(r); - add_edge(u, h->accept, *h); + + void tree_edge(LitTrieEdge e, const LitTrie &trie) { + auto u = source(e, trie); + auto v = target(e, trie); + DEBUG_PRINTF("bfs (%zu, %zu) on '%c'\n", trie[u].index, trie[v].index, + trie[v].c); + ordering.push_back(v); + + auto f = find_failure_target(u, v, trie); + + if (f) { + DEBUG_PRINTF("final failure vertex %zu\n", trie[f].index); + failure_map.emplace(v, f); + + // Propagate reports from failure path to ensure we correctly + // report substrings. + insert(&mutable_trie[v].reports, mutable_trie[f].reports); + } else { + DEBUG_PRINTF("final failure vertex root\n"); + failure_map.emplace(v, trie.root); + } + } + +private: + LitTrie &mutable_trie; //!< For setting reports property. + map &failure_map; + vector &ordering; //!< BFS ordering for vertices. +}; +} + +static UNUSED +bool isSaneTrie(const LitTrie &trie) { + CharReach seen; + for (auto u : vertices_range(trie)) { + seen.clear(); + for (auto v : adjacent_vertices_range(u, trie)) { + if (seen.test(trie[v].c)) { + return false; + } + seen.set(trie[v].c); + } + } + return true; +} + +/** + * \brief Turn the given literal trie into an AC automaton by adding additional + * edges and reports. + */ +static +void buildAutomaton(LitTrie &trie) { + assert(isSaneTrie(trie)); + + // Find our failure transitions and reports. + map failure_map; + vector ordering; + ACVisitor ac_vis(trie, failure_map, ordering); + boost::breadth_first_search(trie, trie.root, visitor(ac_vis)); + + // Compute missing edges from failure map. + for (auto v : ordering) { + DEBUG_PRINTF("vertex %zu\n", trie[v].index); + CharReach seen; + for (auto w : adjacent_vertices_range(v, trie)) { + DEBUG_PRINTF("edge to %zu with reach 0x%02x\n", trie[w].index, + trie[w].c); + assert(!seen.test(trie[w].c)); + seen.set(trie[w].c); + } + auto parent = failure_map.at(v); + for (auto w : adjacent_vertices_range(parent, trie)) { + if (!seen.test(trie[w].c)) { + add_edge(v, w, trie); + } + } + } +} + +static +vector getAlphabet(const LitTrie &trie, bool nocase) { + vector esets = {CharReach::dot()}; + for (auto v : vertices_range(trie)) { + if (v == trie.root) { + continue; + } + + CharReach cr; + if (nocase) { + cr.set(mytoupper(trie[v].c)); + cr.set(mytolower(trie[v].c)); + } else { + cr.set(trie[v].c); + } + + for (size_t i = 0; i < esets.size(); i++) { + if (esets[i].count() == 1) { + continue; + } + + CharReach t = cr & esets[i]; + if (t.any() && t != esets[i]) { + esets[i] &= ~t; + esets.push_back(t); + } + } + } + + // For deterministic compiles. + sort(esets.begin(), esets.end()); + return esets; +} + +static +u16 buildAlphabet(const LitTrie &trie, bool nocase, + array &alpha, + array &unalpha) { + const auto &esets = getAlphabet(trie, nocase); + + u16 i = 0; + for (const auto &cr : esets) { + u16 leader = cr.find_first(); + for (size_t s = cr.find_first(); s != cr.npos; s = cr.find_next(s)) { + alpha[s] = i; + } + unalpha[i] = leader; + i++; + } + + for (u16 j = N_CHARS; j < ALPHABET_SIZE; j++, i++) { + alpha[j] = i; + unalpha[i] = j; + } + + DEBUG_PRINTF("alphabet size %u\n", i); + return i; +} + +/** \brief Construct a raw_dfa from a literal trie. */ +static +unique_ptr buildDfa(LitTrie &trie, bool nocase) { + DEBUG_PRINTF("trie has %zu states\n", num_vertices(trie)); + + buildAutomaton(trie); + + auto rdfa = make_unique(NFA_OUTFIX); + + // Calculate alphabet. + array unalpha; + auto &alpha = rdfa->alpha_remap; + rdfa->alpha_size = buildAlphabet(trie, nocase, alpha, unalpha); + + // Construct states and transitions. + const u16 root_state = DEAD_STATE + 1; + rdfa->start_anchored = root_state; + rdfa->start_floating = root_state; + rdfa->states.resize(num_vertices(trie) + 1, dstate(rdfa->alpha_size)); + + // Dead state. + fill(rdfa->states[DEAD_STATE].next.begin(), + rdfa->states[DEAD_STATE].next.end(), DEAD_STATE); + + for (auto u : vertices_range(trie)) { + auto u_state = trie[u].index + 1; + DEBUG_PRINTF("state %zu\n", u_state); + assert(u_state < rdfa->states.size()); + auto &ds = rdfa->states[u_state]; + ds.daddy = root_state; + ds.reports = trie[u].reports; + + if (!ds.reports.empty()) { + DEBUG_PRINTF("reports: %s\n", as_string_list(ds.reports).c_str()); + } + + // By default, transition back to the root. + fill(ds.next.begin(), ds.next.end(), root_state); + // TOP should be a self-loop. + ds.next[alpha[TOP]] = u_state; + + // Add in the real transitions. + for (auto v : adjacent_vertices_range(u, trie)) { + if (v == trie.root) { + continue; + } + auto v_state = trie[v].index + 1; + assert((u16)trie[v].c < alpha.size()); + u16 sym = alpha[trie[v].c]; + DEBUG_PRINTF("edge to %zu on 0x%02x (sym %u)\n", v_state, + trie[v].c, sym); + assert(sym < ds.next.size()); + assert(ds.next[sym] == root_state); + ds.next[sym] = v_state; + } + } + + return rdfa; } bool SmallWriteBuildImpl::determiniseLiterals() { DEBUG_PRINTF("handling literals\n"); assert(!poisoned); - assert(cand_literals.size() <= cc.grey.smallWriteMaxLiterals); + assert(num_literals <= cc.grey.smallWriteMaxLiterals); - if (cand_literals.empty()) { + if (is_empty(lit_trie) && is_empty(lit_trie_nocase)) { + DEBUG_PRINTF("no literals\n"); return true; /* nothing to do */ } - vector > temp_dfas; + vector> dfas; - for (const auto &cand : cand_literals) { - NGHolder h; - DEBUG_PRINTF("determinising %s\n", dumpString(cand.first).c_str()); - lit_to_graph(&h, cand.first, cand.second); - temp_dfas.push_back(buildMcClellan(h, &rm, cc.grey)); - - // If we couldn't build a McClellan DFA for this portion, then we - // can't SmallWrite optimize the entire graph, so we can't - // optimize any of it - if (!temp_dfas.back()) { - DEBUG_PRINTF("failed to determinise\n"); - poisoned = true; - return false; - } + if (!is_empty(lit_trie)) { + dfas.push_back(buildDfa(lit_trie, false)); + DEBUG_PRINTF("caseful literal dfa with %zu states\n", + dfas.back()->states.size()); + } + if (!is_empty(lit_trie_nocase)) { + dfas.push_back(buildDfa(lit_trie_nocase, true)); + DEBUG_PRINTF("nocase literal dfa with %zu states\n", + dfas.back()->states.size()); } - if (!rdfa && temp_dfas.size() == 1) { - /* no need to merge there is only one dfa */ - rdfa = move(temp_dfas[0]); + if (rdfa) { + dfas.push_back(move(rdfa)); + DEBUG_PRINTF("general dfa with %zu states\n", + dfas.back()->states.size()); + } + + // If we only have one DFA, no merging is necessary. + if (dfas.size() == 1) { + DEBUG_PRINTF("only one dfa\n"); + rdfa = move(dfas.front()); return true; } - /* do a merge of the new dfas */ - + // Merge all DFAs. vector to_merge; - - if (rdfa) {/* also include the existing dfa */ - to_merge.push_back(rdfa.get()); - } - - for (const auto &d : temp_dfas) { + for (const auto &d : dfas) { to_merge.push_back(d.get()); } - assert(to_merge.size() > 1); - - while (to_merge.size() > LITERAL_MERGE_CHUNK_SIZE) { - vector small_merge; - small_merge.insert(small_merge.end(), to_merge.begin(), - to_merge.begin() + LITERAL_MERGE_CHUNK_SIZE); - - temp_dfas.push_back( - mergeAllDfas(small_merge, DFA_MERGE_MAX_STATES, &rm, cc.grey)); - - if (!temp_dfas.back()) { - DEBUG_PRINTF("merge failed\n"); - poisoned = true; - return false; - } - - to_merge.erase(to_merge.begin(), - to_merge.begin() + LITERAL_MERGE_CHUNK_SIZE); - to_merge.push_back(temp_dfas.back().get()); - } - auto merged = mergeAllDfas(to_merge, DFA_MERGE_MAX_STATES, &rm, cc.grey); if (!merged) { @@ -405,11 +640,11 @@ bool SmallWriteBuildImpl::determiniseLiterals() { return false; } - DEBUG_PRINTF("merge succeeded, built %p\n", merged.get()); + DEBUG_PRINTF("merge succeeded, built dfa with %zu states\n", + merged->states.size()); - // Replace our only DFA with the merged one + // Replace our only DFA with the merged one. rdfa = move(merged); - return true; } @@ -527,7 +762,7 @@ unique_ptr makeSmallWriteBuilder(size_t num_patterns, } bytecode_ptr SmallWriteBuildImpl::build(u32 roseQuality) { - if (!rdfa && cand_literals.empty()) { + if (!rdfa && is_empty(lit_trie) && is_empty(lit_trie_nocase)) { DEBUG_PRINTF("no smallwrite engine\n"); poisoned = true; return nullptr; @@ -579,9 +814,10 @@ set SmallWriteBuildImpl::all_reports() const { if (rdfa) { insert(&reports, ::ue2::all_reports(*rdfa)); } - for (const auto &cand : cand_literals) { - reports.insert(cand.second); - } + + insert(&reports, ::ue2::all_reports(lit_trie)); + insert(&reports, ::ue2::all_reports(lit_trie_nocase)); + return reports; } diff --git a/src/smallwrite/smallwrite_build.h b/src/smallwrite/smallwrite_build.h index 92222d62..648b13db 100644 --- a/src/smallwrite/smallwrite_build.h +++ b/src/smallwrite/smallwrite_build.h @@ -30,13 +30,14 @@ #define SMWR_BUILD_H /** - * SmallWrite Build interface. Everything you ever needed to feed literals in - * and get a SmallWriteEngine out. This header should be everything needed by - * the rest of UE2. + * \file + * \brief Small-write engine build interface. + * + * Everything you ever needed to feed literals in and get a SmallWriteEngine + * out. This header should be everything needed by the rest of UE2. */ #include "ue2common.h" -#include "util/alloc.h" #include "util/bytecode_ptr.h" #include "util/noncopyable.h" @@ -53,14 +54,14 @@ class ExpressionInfo; class NGHolder; class ReportManager; -// Abstract interface intended for callers from elsewhere in the tree, real -// underlying implementation is SmallWriteBuildImpl in smwr_build_impl.h. +/** + * Abstract interface intended for callers from elsewhere in the tree, real + * underlying implementation is SmallWriteBuildImpl in smwr_build_impl.h. + */ class SmallWriteBuild : noncopyable { public: - // Destructor virtual ~SmallWriteBuild(); - // Construct a runtime implementation. virtual bytecode_ptr build(u32 roseQuality) = 0; virtual void add(const NGHolder &g, const ExpressionInfo &expr) = 0; @@ -69,7 +70,7 @@ public: virtual std::set all_reports() const = 0; }; -// Construct a usable SmallWrite builder. +/** \brief Construct a usable SmallWrite builder. */ std::unique_ptr makeSmallWriteBuilder(size_t num_patterns, const ReportManager &rm, const CompileContext &cc);