smallwrite: aho-corasick construction for literals

This commit is contained in:
Justin Viiret 2017-03-31 14:04:44 +11:00 committed by Matthew Barr
parent b75b169b49
commit d4c66e294b
2 changed files with 333 additions and 96 deletions

View File

@ -26,6 +26,11 @@
* POSSIBILITY OF SUCH DAMAGE. * POSSIBILITY OF SUCH DAMAGE.
*/ */
/**
* \file
* \brief Small-write engine build code.
*/
#include "smallwrite/smallwrite_build.h" #include "smallwrite/smallwrite_build.h"
#include "grey.h" #include "grey.h"
@ -48,6 +53,7 @@
#include "util/alloc.h" #include "util/alloc.h"
#include "util/bytecode_ptr.h" #include "util/bytecode_ptr.h"
#include "util/charreach.h" #include "util/charreach.h"
#include "util/compare.h"
#include "util/compile_context.h" #include "util/compile_context.h"
#include "util/container.h" #include "util/container.h"
#include "util/make_unique.h" #include "util/make_unique.h"
@ -60,36 +66,58 @@
#include <vector> #include <vector>
#include <utility> #include <utility>
#include <boost/graph/breadth_first_search.hpp>
using namespace std; using namespace std;
namespace ue2 { namespace ue2 {
#define LITERAL_MERGE_CHUNK_SIZE 25
#define DFA_MERGE_MAX_STATES 8000 #define DFA_MERGE_MAX_STATES 8000
#define MAX_TRIE_VERTICES 8000 #define MAX_TRIE_VERTICES 8000
namespace { // unnamed
struct LitTrieVertexProps { struct LitTrieVertexProps {
LitTrieVertexProps() = default; LitTrieVertexProps() = default;
explicit LitTrieVertexProps(char c_in) : c(c_in) {} explicit LitTrieVertexProps(u8 c_in) : c(c_in) {}
char c = 0;
size_t index; // managed by ue2_graph size_t index; // managed by ue2_graph
u8 c = 0; //!< character reached on this vertex
flat_set<ReportID> reports; //!< managed reports fired on this vertex
}; };
struct LitTrieEdgeProps { struct LitTrieEdgeProps {
LitTrieEdgeProps() = default;
size_t index; // managed by ue2_graph size_t index; // managed by ue2_graph
}; };
/**
* \brief BGL graph used to store a trie of literals (for later AC construction
* into a DFA).
*/
struct LitTrie struct LitTrie
: public ue2_graph<LitTrie, LitTrieVertexProps, LitTrieEdgeProps> { : public ue2_graph<LitTrie, LitTrieVertexProps, LitTrieEdgeProps> {
LitTrie() : root(add_vertex(*this)) {} LitTrie() : root(add_vertex(*this)) {}
const vertex_descriptor root; const vertex_descriptor root; //!< Root vertex for the trie.
}; };
/**
 * \brief True if the trie holds no vertices other than its root.
 *
 * LitTrie's constructor always adds the root vertex, so a vertex count of one
 * (or fewer) means that no literal characters have been inserted.
 */
static
bool is_empty(const LitTrie &trie) {
    const auto vertex_count = num_vertices(trie);
    return !(vertex_count > 1);
}
/**
 * \brief Gather every report attached to any vertex of the trie.
 *
 * \param trie Literal trie to scan.
 * \return The union of the per-vertex report sets.
 */
static
std::set<ReportID> all_reports(const LitTrie &trie) {
    std::set<ReportID> collected;
    for (const auto vertex : vertices_range(trie)) {
        const auto &vertex_reports = trie[vertex].reports;
        insert(&collected, vertex_reports);
    }
    return collected;
}
using LitTrieVertex = LitTrie::vertex_descriptor;
using LitTrieEdge = LitTrie::edge_descriptor;
namespace { // unnamed
// Concrete impl class // Concrete impl class
class SmallWriteBuildImpl : public SmallWriteBuild { class SmallWriteBuildImpl : public SmallWriteBuild {
public: public:
@ -110,15 +138,15 @@ public:
const CompileContext &cc; const CompileContext &cc;
unique_ptr<raw_dfa> rdfa; unique_ptr<raw_dfa> rdfa;
vector<pair<ue2_literal, ReportID> > cand_literals;
LitTrie lit_trie; LitTrie lit_trie;
LitTrie lit_trie_nocase; LitTrie lit_trie_nocase;
size_t num_literals = 0;
bool poisoned; bool poisoned;
}; };
} // namespace } // namespace
SmallWriteBuild::~SmallWriteBuild() { } SmallWriteBuild::~SmallWriteBuild() = default;
SmallWriteBuildImpl::SmallWriteBuildImpl(size_t num_patterns, SmallWriteBuildImpl::SmallWriteBuildImpl(size_t num_patterns,
const ReportManager &rm_in, const ReportManager &rm_in,
@ -272,25 +300,27 @@ void SmallWriteBuildImpl::add(const NGHolder &g, const ExpressionInfo &expr) {
} }
static static
bool add_to_trie(const ue2_literal &literal, LitTrie &trie) { bool add_to_trie(const ue2_literal &literal, ReportID report, LitTrie &trie) {
auto u = trie.root; auto u = trie.root;
for (auto &c : literal) { for (const auto &c : literal) {
auto next = LitTrie::null_vertex(); auto next = LitTrie::null_vertex();
for (auto v : adjacent_vertices_range(u, trie)) { for (auto v : adjacent_vertices_range(u, trie)) {
if (trie[v].c == c.c) { if (trie[v].c == (u8)c.c) {
next = v; next = v;
break; break;
} }
} }
if (next == LitTrie::null_vertex()) { if (!next) {
next = add_vertex(LitTrieVertexProps(c.c), trie); next = add_vertex(LitTrieVertexProps((u8)c.c), trie);
add_edge(u, next, trie); add_edge(u, next, trie);
} }
u = next; u = next;
} }
DEBUG_PRINTF("added '%s' to trie, now %zu vertices\n", trie[u].reports.insert(report);
escapeString(literal).c_str(), num_vertices(trie));
DEBUG_PRINTF("added '%s' (report %u) to trie, now %zu vertices\n",
escapeString(literal).c_str(), report, num_vertices(trie));
return num_vertices(trie) <= MAX_TRIE_VERTICES; return num_vertices(trie) <= MAX_TRIE_VERTICES;
} }
@ -298,105 +328,310 @@ void SmallWriteBuildImpl::add(const ue2_literal &literal, ReportID r) {
// If the graph is poisoned (i.e. we can't build a SmallWrite version), // If the graph is poisoned (i.e. we can't build a SmallWrite version),
// we don't even try. // we don't even try.
if (poisoned) { if (poisoned) {
DEBUG_PRINTF("poisoned\n");
return; return;
} }
if (literal.length() > cc.grey.smallWriteLargestBuffer) { if (literal.length() > cc.grey.smallWriteLargestBuffer) {
DEBUG_PRINTF("exceeded length limit\n");
return; /* too long */ return; /* too long */
} }
cand_literals.push_back(make_pair(literal, r)); if (++num_literals > cc.grey.smallWriteMaxLiterals) {
DEBUG_PRINTF("exceeded literal limit\n");
if (!add_to_trie(literal,
literal.any_nocase() ? lit_trie_nocase : lit_trie)) {
poisoned = true; poisoned = true;
return; return;
} }
if (cand_literals.size() > cc.grey.smallWriteMaxLiterals) { auto &trie = literal.any_nocase() ? lit_trie_nocase : lit_trie;
if (!add_to_trie(literal, r, trie)) {
DEBUG_PRINTF("trie add failed\n");
poisoned = true; poisoned = true;
} }
} }
static namespace {
void lit_to_graph(NGHolder *h, const ue2_literal &literal, ReportID r) {
NFAVertex u = h->startDs; /**
for (const auto &c : literal) { * \brief BFS visitor for Aho-Corasick automaton construction.
NFAVertex v = add_vertex(*h); *
add_edge(u, v, *h); * This is doing two things:
(*h)[v].char_reach = c; *
u = v; * - Computing the failure edges (also called fall or supply edges) for each
* vertex, giving the longest suffix of the path to that point that is also
* a prefix in the trie reached on the same character. The BFS traversal
* makes it possible to build these from earlier failure paths.
*
* - Computing the output function for each vertex, which is done by
* propagating the reports from failure paths as well. This ensures that
* substrings of the current path also report correctly.
*/
// Visitor handed to boost::breadth_first_search by buildAutomaton(); it fills
// in the caller-owned failure map and BFS ordering and updates vertex reports.
struct ACVisitor : public boost::default_bfs_visitor {
ACVisitor(LitTrie &trie_in,
map<LitTrieVertex, LitTrieVertex> &failure_map_in,
vector<LitTrieVertex> &ordering_in)
: mutable_trie(trie_in), failure_map(failure_map_in),
ordering(ordering_in) {}
// Find the failure target for v, which is being discovered via the edge
// (u, v). Starting at u, repeatedly steps to the recorded failure target and
// returns the first child of such a target that carries the same character
// as v. Returns LitTrie::null_vertex() if the walk reaches the root without
// finding one.
LitTrieVertex find_failure_target(LitTrieVertex u, LitTrieVertex v,
const LitTrie &trie) {
// u must already have a failure entry (unless it is the root), and v must
// not have one yet.
assert(u == trie.root || contains(failure_map, u));
assert(!contains(failure_map, v));
const auto &c = trie[v].c;
while (u != trie.root) {
// u is not the root here, so its failure target has been recorded.
auto f = failure_map.at(u);
for (auto w : adjacent_vertices_range(f, trie)) {
if (trie[w].c == c) {
return w;
}
}
// No match below f: continue the walk from f.
u = f;
}
DEBUG_PRINTF("no failure edge\n");
return LitTrie::null_vertex();
} }
(*h)[u].reports.insert(r);
// tree_edge: appends the edge's target v to the BFS ordering, records v's
// failure target (the root when none is found) and, when a non-root target
// exists, merges that target's reports into v.
add_edge(u, h->accept, *h); void tree_edge(LitTrieEdge e, const LitTrie &trie) {
auto u = source(e, trie);
auto v = target(e, trie);
DEBUG_PRINTF("bfs (%zu, %zu) on '%c'\n", trie[u].index, trie[v].index,
trie[v].c);
ordering.push_back(v);
auto f = find_failure_target(u, v, trie);
if (f) {
DEBUG_PRINTF("final failure vertex %zu\n", trie[f].index);
failure_map.emplace(v, f);
// Propagate reports from failure path to ensure we correctly
// report substrings.
// The trie parameter is const, so the write goes through mutable_trie.
insert(&mutable_trie[v].reports, mutable_trie[f].reports);
} else {
DEBUG_PRINTF("final failure vertex root\n");
failure_map.emplace(v, trie.root);
}
}
private:
LitTrie &mutable_trie; //!< For setting reports property.
map<LitTrieVertex, LitTrieVertex> &failure_map;
vector<LitTrieVertex> &ordering; //!< BFS ordering for vertices.
};
}
/**
 * \brief Sanity check: no vertex may have two successors that carry the same
 * character.
 *
 * \return false as soon as a duplicated character is found among the
 * successors of any vertex, true otherwise.
 */
static UNUSED
bool isSaneTrie(const LitTrie &trie) {
    for (auto parent : vertices_range(trie)) {
        // Characters already claimed by a successor of this vertex.
        CharReach used;
        for (auto child : adjacent_vertices_range(parent, trie)) {
            const u8 ch = trie[child].c;
            if (used.test(ch)) {
                return false;
            }
            used.set(ch);
        }
    }
    return true;
}
/**
 * \brief Convert the literal trie into an Aho-Corasick automaton in place.
 *
 * A breadth-first pass with ACVisitor records a failure target for every
 * vertex it discovers and propagates reports along those failure links. Each
 * discovered vertex then receives an edge to every successor of its failure
 * target whose character it does not already have a successor for.
 */
static
void buildAutomaton(LitTrie &trie) {
    assert(isSaneTrie(trie));

    // BFS pass: fills failure_map and the discovery ordering, and merges
    // reports along failure links.
    map<LitTrieVertex, LitTrieVertex> failure_map;
    vector<LitTrieVertex> ordering;
    ACVisitor ac_vis(trie, failure_map, ordering);
    boost::breadth_first_search(trie, trie.root, visitor(ac_vis));

    // Walk vertices in discovery order and add the edges they are missing.
    for (auto cur : ordering) {
        DEBUG_PRINTF("vertex %zu\n", trie[cur].index);

        // Characters this vertex already has a successor for.
        CharReach own_chars;
        for (auto succ : adjacent_vertices_range(cur, trie)) {
            DEBUG_PRINTF("edge to %zu with reach 0x%02x\n", trie[succ].index,
                         trie[succ].c);
            const u8 ch = trie[succ].c;
            assert(!own_chars.test(ch));
            own_chars.set(ch);
        }

        // Borrow every successor of the failure target that we lack.
        const auto fallback = failure_map.at(cur);
        for (auto succ : adjacent_vertices_range(fallback, trie)) {
            if (own_chars.test(trie[succ].c)) {
                continue;
            }
            add_edge(cur, succ, trie);
        }
    }
}
/**
 * \brief Partition the characters into classes for the trie's alphabet.
 *
 * Starts from the single class CharReach::dot() and, for every non-root
 * vertex, splits the characters wanted by that vertex (its character, or both
 * case variants when \p nocase is set) out of any larger class that only
 * partly overlaps them.
 *
 * \return The resulting classes in sorted order.
 */
static
vector<CharReach> getAlphabet(const LitTrie &trie, bool nocase) {
    vector<CharReach> classes = {CharReach::dot()};

    for (auto vertex : vertices_range(trie)) {
        if (vertex == trie.root) {
            continue;
        }

        CharReach wanted;
        if (!nocase) {
            wanted.set(trie[vertex].c);
        } else {
            wanted.set(mytoupper(trie[vertex].c));
            wanted.set(mytolower(trie[vertex].c));
        }

        // Index-based on purpose: push_back below may grow the vector, and
        // freshly appended classes are visited by this same loop.
        for (size_t idx = 0; idx < classes.size(); idx++) {
            // A class holding a single character cannot be split further.
            if (classes[idx].count() != 1) {
                CharReach overlap = wanted & classes[idx];
                if (overlap.any() && overlap != classes[idx]) {
                    classes[idx] &= ~overlap;
                    classes.push_back(overlap);
                }
            }
        }
    }

    // For deterministic compiles.
    sort(classes.begin(), classes.end());
    return classes;
}
/**
 * \brief Fill the character/symbol remapping tables for the trie.
 *
 * Every class returned by getAlphabet() gets one symbol: \p alpha maps each
 * member character to that symbol and \p unalpha maps the symbol back to the
 * first character of the class. The entries from N_CHARS up to ALPHABET_SIZE
 * are then given one symbol each.
 *
 * \return The number of symbols assigned.
 */
static
u16 buildAlphabet(const LitTrie &trie, bool nocase,
                  array<u16, ALPHABET_SIZE> &alpha,
                  array<u16, ALPHABET_SIZE> &unalpha) {
    const vector<CharReach> classes = getAlphabet(trie, nocase);

    u16 sym = 0;
    for (const auto &cr : classes) {
        const u16 leader = cr.find_first();
        size_t member = cr.find_first();
        while (member != cr.npos) {
            alpha[member] = sym;
            member = cr.find_next(member);
        }
        unalpha[sym] = leader;
        sym++;
    }

    // Remaining entries beyond the ordinary characters: one symbol apiece.
    u16 special = N_CHARS;
    while (special < ALPHABET_SIZE) {
        alpha[special] = sym;
        unalpha[sym] = special;
        special++;
        sym++;
    }

    DEBUG_PRINTF("alphabet size %u\n", sym);
    return sym;
}
/** \brief Construct a raw_dfa from a literal trie. */
static
unique_ptr<raw_dfa> buildDfa(LitTrie &trie, bool nocase) {
DEBUG_PRINTF("trie has %zu states\n", num_vertices(trie));
buildAutomaton(trie);
auto rdfa = make_unique<raw_dfa>(NFA_OUTFIX);
// Calculate alphabet.
array<u16, ALPHABET_SIZE> unalpha;
auto &alpha = rdfa->alpha_remap;
rdfa->alpha_size = buildAlphabet(trie, nocase, alpha, unalpha);
// Construct states and transitions.
const u16 root_state = DEAD_STATE + 1;
rdfa->start_anchored = root_state;
rdfa->start_floating = root_state;
rdfa->states.resize(num_vertices(trie) + 1, dstate(rdfa->alpha_size));
// Dead state.
fill(rdfa->states[DEAD_STATE].next.begin(),
rdfa->states[DEAD_STATE].next.end(), DEAD_STATE);
for (auto u : vertices_range(trie)) {
auto u_state = trie[u].index + 1;
DEBUG_PRINTF("state %zu\n", u_state);
assert(u_state < rdfa->states.size());
auto &ds = rdfa->states[u_state];
ds.daddy = root_state;
ds.reports = trie[u].reports;
if (!ds.reports.empty()) {
DEBUG_PRINTF("reports: %s\n", as_string_list(ds.reports).c_str());
}
// By default, transition back to the root.
fill(ds.next.begin(), ds.next.end(), root_state);
// TOP should be a self-loop.
ds.next[alpha[TOP]] = u_state;
// Add in the real transitions.
for (auto v : adjacent_vertices_range(u, trie)) {
if (v == trie.root) {
continue;
}
auto v_state = trie[v].index + 1;
assert((u16)trie[v].c < alpha.size());
u16 sym = alpha[trie[v].c];
DEBUG_PRINTF("edge to %zu on 0x%02x (sym %u)\n", v_state,
trie[v].c, sym);
assert(sym < ds.next.size());
assert(ds.next[sym] == root_state);
ds.next[sym] = v_state;
}
}
return rdfa;
} }
bool SmallWriteBuildImpl::determiniseLiterals() { bool SmallWriteBuildImpl::determiniseLiterals() {
DEBUG_PRINTF("handling literals\n"); DEBUG_PRINTF("handling literals\n");
assert(!poisoned); assert(!poisoned);
assert(cand_literals.size() <= cc.grey.smallWriteMaxLiterals); assert(num_literals <= cc.grey.smallWriteMaxLiterals);
if (cand_literals.empty()) { if (is_empty(lit_trie) && is_empty(lit_trie_nocase)) {
DEBUG_PRINTF("no literals\n");
return true; /* nothing to do */ return true; /* nothing to do */
} }
vector<unique_ptr<raw_dfa> > temp_dfas; vector<unique_ptr<raw_dfa>> dfas;
for (const auto &cand : cand_literals) { if (!is_empty(lit_trie)) {
NGHolder h; dfas.push_back(buildDfa(lit_trie, false));
DEBUG_PRINTF("determinising %s\n", dumpString(cand.first).c_str()); DEBUG_PRINTF("caseful literal dfa with %zu states\n",
lit_to_graph(&h, cand.first, cand.second); dfas.back()->states.size());
temp_dfas.push_back(buildMcClellan(h, &rm, cc.grey)); }
if (!is_empty(lit_trie_nocase)) {
// If we couldn't build a McClellan DFA for this portion, then we dfas.push_back(buildDfa(lit_trie_nocase, true));
// can't SmallWrite optimize the entire graph, so we can't DEBUG_PRINTF("nocase literal dfa with %zu states\n",
// optimize any of it dfas.back()->states.size());
if (!temp_dfas.back()) {
DEBUG_PRINTF("failed to determinise\n");
poisoned = true;
return false;
}
} }
if (!rdfa && temp_dfas.size() == 1) { if (rdfa) {
/* no need to merge there is only one dfa */ dfas.push_back(move(rdfa));
rdfa = move(temp_dfas[0]); DEBUG_PRINTF("general dfa with %zu states\n",
dfas.back()->states.size());
}
// If we only have one DFA, no merging is necessary.
if (dfas.size() == 1) {
DEBUG_PRINTF("only one dfa\n");
rdfa = move(dfas.front());
return true; return true;
} }
/* do a merge of the new dfas */ // Merge all DFAs.
vector<const raw_dfa *> to_merge; vector<const raw_dfa *> to_merge;
for (const auto &d : dfas) {
if (rdfa) {/* also include the existing dfa */
to_merge.push_back(rdfa.get());
}
for (const auto &d : temp_dfas) {
to_merge.push_back(d.get()); to_merge.push_back(d.get());
} }
assert(to_merge.size() > 1);
while (to_merge.size() > LITERAL_MERGE_CHUNK_SIZE) {
vector<const raw_dfa *> small_merge;
small_merge.insert(small_merge.end(), to_merge.begin(),
to_merge.begin() + LITERAL_MERGE_CHUNK_SIZE);
temp_dfas.push_back(
mergeAllDfas(small_merge, DFA_MERGE_MAX_STATES, &rm, cc.grey));
if (!temp_dfas.back()) {
DEBUG_PRINTF("merge failed\n");
poisoned = true;
return false;
}
to_merge.erase(to_merge.begin(),
to_merge.begin() + LITERAL_MERGE_CHUNK_SIZE);
to_merge.push_back(temp_dfas.back().get());
}
auto merged = mergeAllDfas(to_merge, DFA_MERGE_MAX_STATES, &rm, cc.grey); auto merged = mergeAllDfas(to_merge, DFA_MERGE_MAX_STATES, &rm, cc.grey);
if (!merged) { if (!merged) {
@ -405,11 +640,11 @@ bool SmallWriteBuildImpl::determiniseLiterals() {
return false; return false;
} }
DEBUG_PRINTF("merge succeeded, built %p\n", merged.get()); DEBUG_PRINTF("merge succeeded, built dfa with %zu states\n",
merged->states.size());
// Replace our only DFA with the merged one // Replace our only DFA with the merged one.
rdfa = move(merged); rdfa = move(merged);
return true; return true;
} }
@ -527,7 +762,7 @@ unique_ptr<SmallWriteBuild> makeSmallWriteBuilder(size_t num_patterns,
} }
bytecode_ptr<SmallWriteEngine> SmallWriteBuildImpl::build(u32 roseQuality) { bytecode_ptr<SmallWriteEngine> SmallWriteBuildImpl::build(u32 roseQuality) {
if (!rdfa && cand_literals.empty()) { if (!rdfa && is_empty(lit_trie) && is_empty(lit_trie_nocase)) {
DEBUG_PRINTF("no smallwrite engine\n"); DEBUG_PRINTF("no smallwrite engine\n");
poisoned = true; poisoned = true;
return nullptr; return nullptr;
@ -579,9 +814,10 @@ set<ReportID> SmallWriteBuildImpl::all_reports() const {
if (rdfa) { if (rdfa) {
insert(&reports, ::ue2::all_reports(*rdfa)); insert(&reports, ::ue2::all_reports(*rdfa));
} }
for (const auto &cand : cand_literals) {
reports.insert(cand.second); insert(&reports, ::ue2::all_reports(lit_trie));
} insert(&reports, ::ue2::all_reports(lit_trie_nocase));
return reports; return reports;
} }

View File

@ -30,13 +30,14 @@
#define SMWR_BUILD_H #define SMWR_BUILD_H
/** /**
* SmallWrite Build interface. Everything you ever needed to feed literals in * \file
* and get a SmallWriteEngine out. This header should be everything needed by * \brief Small-write engine build interface.
* the rest of UE2. *
* Everything you ever needed to feed literals in and get a SmallWriteEngine
* out. This header should be everything needed by the rest of UE2.
*/ */
#include "ue2common.h" #include "ue2common.h"
#include "util/alloc.h"
#include "util/bytecode_ptr.h" #include "util/bytecode_ptr.h"
#include "util/noncopyable.h" #include "util/noncopyable.h"
@ -53,14 +54,14 @@ class ExpressionInfo;
class NGHolder; class NGHolder;
class ReportManager; class ReportManager;
// Abstract interface intended for callers from elsewhere in the tree, real /**
// underlying implementation is SmallWriteBuildImpl in smwr_build_impl.h. * Abstract interface intended for callers from elsewhere in the tree, real
* underlying implementation is SmallWriteBuildImpl in smwr_build_impl.h.
*/
class SmallWriteBuild : noncopyable { class SmallWriteBuild : noncopyable {
public: public:
// Destructor
virtual ~SmallWriteBuild(); virtual ~SmallWriteBuild();
// Construct a runtime implementation.
virtual bytecode_ptr<SmallWriteEngine> build(u32 roseQuality) = 0; virtual bytecode_ptr<SmallWriteEngine> build(u32 roseQuality) = 0;
virtual void add(const NGHolder &g, const ExpressionInfo &expr) = 0; virtual void add(const NGHolder &g, const ExpressionInfo &expr) = 0;
@ -69,7 +70,7 @@ public:
virtual std::set<ReportID> all_reports() const = 0; virtual std::set<ReportID> all_reports() const = 0;
}; };
// Construct a usable SmallWrite builder. /** \brief Construct a usable SmallWrite builder. */
std::unique_ptr<SmallWriteBuild> std::unique_ptr<SmallWriteBuild>
makeSmallWriteBuilder(size_t num_patterns, const ReportManager &rm, makeSmallWriteBuilder(size_t num_patterns, const ReportManager &rm,
const CompileContext &cc); const CompileContext &cc);