ac自动机

󰃭 2016-08-05

AC自动机算法原理及应用

ac自动机的背景

在字符串搜索中,某些场景下——如在一篇长文档中搜索一系列的子串(常见的敏感词过滤),会出现需要对文档进行多次扫描的问题,时间复杂度达到O(mn)。为了提高性能,ac自动机应运而生。

原理

ac自动机的核心思想是减少相同前缀的重复搜索。为了达到这一目的,需要以下条件:

  • 需要一种结构,将具有相同前缀的pattern绑定在一起——Trie树
  • 需要一直规则,在匹配失败的时候能够快速跳转或者回溯——失败指针

Trie树

Trie树又称字典树,它利用树结构的特点,将具有公共前缀的数据聚集到同一个节点之下,从而在搜索的时候省略掉公共前缀的搜索时间。
Trie树需要满足以下几个特征:

  • 根节点无数据,其余节点含单元数据
  • 根节点到叶节点的路径为数据集中的一条数据元素
  • 每个节点的子节点互斥

失败指针

失败指针(Fail)也可以成为跳转指针(Turnto),一般在树结构填充完毕后进行统计,当然也可以在插入的时候动态更新,他可以在匹配失败时快速定位到其他具有相同前缀的数据,或者回溯到上一层的前缀。一般来说,所有的失败指针连接的串都是平行的(除root),而他们的最终归宿又都在root。

实现

/**
 * ac.h
 * Author : erdao
 * Date : 2016/8/5
 */

#ifndef _AC_H__
#define _AC_H__

#include <stdint.h>
#include <unordered_map>
#include <vector>

//< 转换规则,用作map的key
class TransformPolicy {
public:
	virtual uint64_t transform(void *) = 0;
};


// Trie树结构的定义
template <typename valueType, typename tagType, class Trans>
struct TrieNode {
	//unordered_map快速定位子节点
	//using ChildMapType = std::unordered_map<uint64_t, TrieNode *>;

	valueType 	value;
	tagType 	tag;
	bool		flag;

	std::unordered_map<uint64_t, TrieNode *>	children;
	TrieNode		*fail;

	TrieNode()
		: value{}
		, tag{}
		, flag{false}
		, fail{ nullptr }
	{}
};

template <typename valueType, typename tagType, class Trans>
class AC_Automation {
	using ACNode = TrieNode<valueType, tagType, Trans>;

public:
	AC_Automation()
		: pRoot(nullptr) {
		pRoot = new ACNode{};
	}
	~AC_Automation() {
		this->removeNode(pRoot);
		pRoot = nullptr;
	}

	void addData(valueType* pValArray, uint32_t uValNum, tagType vTag);
	void buildFail();

	//返回的格式包括位置和tagType
	std::vector<std::pair<uint32_t, tagType>> Search(valueType *pValArray, uint32_t uValNum);

private:
	void removeNode(ACNode *pNode) {
		if (pNode) {
			for (auto &child : pNode->children) {
				removeNode(child.second);
			}
			pNode->children.clear();
			delete pNode;
		}
	}

private:
	ACNode *pRoot;
};

#endif//_AC_H__
/**
 * ac.cpp
 * Author : erdao
 * Date : 2016/8/5
 */
 
#include <stdint.h>
#include <queue>
#include <vector>

#include "ac.h"

template <typename valueType, typename tagType, class Trans>
void AC_Automation<valueType, tagType, Trans>::addData(valueType* pValArray, uint32_t uValNum, tagType vTag)  {
	using ACNode = TrieNode<valueType, tagType, Trans>;
	
	Trans t{};
	ACNode *p = pRoot;
	for (uint32_t i = 0; i < uValNum; ++i) {
		uint64_t key = t.transform(&pValArray[i]);
		if (p->children.find(key) != p->children.end()) {
			p = p->children[key];
		}
		else {
			p->children[key] = new ACNode{};
			p = p->children[key];
			p->value = pValArray[i];
		}
	}
	p->tag = vTag;
	p->flag = true;
}

template <typename valueType, typename tagType, class Trans>
void AC_Automation<valueType, tagType, Trans>::buildFail() {
	using ACNode = TrieNode<valueType, tagType, Trans>;
	
	//实际上,由于所有的fail指针都相交于pRoot,且相互间平行,不存在交叉情况,可以采用自上而下的遍历方式进行创建
	Trans t{};
	std::queue<ACNode *> qNode;
	qNode.push(pRoot);
	/*pRoot默认指向nullptr,可以作为fail跳转终止条件;另外pRoot本身也是终止条件,按照喜好处理*/

	while (!qNode.empty()) {
		ACNode *p = qNode.front();
		//为p的所有child创建fail
		for (const auto &child : p->children) {
			if (p == pRoot) {
				//深度为1的节点失败都跳转到root
				child.second->fail = pRoot;
			}
			//否则的话,寻找父节点失败指针串中与child值相同的节点
			else{
				uint64_t key = t.transform(&child.second->value);
				ACNode *pFail = p->fail;
				while (pFail) {
					if (pFail->children.find(key) != pFail->children.end()) {
						child.second->fail = pFail->children[key];
						break;
					}
					pFail = pFail->fail;
				}
				if (!pFail)
					child.second->fail = pRoot;
			}
			//将本身处理好的节点放入队列,准备处理子节点
			qNode.push(child.second);
		}
		//将子节点处理好的节点弹出队列
		qNode.pop();
	}	

}

template <typename valueType, typename tagType, class Trans>
std::vector<std::pair<uint32_t, tagType>> AC_Automation<valueType, tagType, Trans>::Search(valueType *pValArray, uint32_t uValNum) {
	using ACNode = TrieNode<valueType, tagType, Trans>;

	//搜索的入口从root开始,不停的迭代Trie树中的fail指针,匹配上时进行记录
	std::vector<std::pair<uint32_t, tagType>> vRes;
	ACNode *p = pRoot;
	Trans t{};
	for (uint32_t i = 0; i < uValNum; ++i) {
		uint64_t key = t.transform(&pValArray[i]);
		//寻找到当前p下fail串中能匹配上的p
		while (p->children.find(key) == p->children.end() && p != pRoot) {
			p = p->fail;
		}
		//如果遍历完没找到的话,那么终止条件是p=pRoot(所有fail最终交会)
		if (p->children.find(key) == p->children.end())
			continue;
		p = p->children[key];

		//如果找到了,检查这一层的fail中是否具有终结标记
		auto pFail = p;
		while (pFail) {
			if (pFail->flag)
				vRes.push_back(std::move (std::make_pair(i, pFail->tag)));
			pFail = pFail->fail;
		}
	}

	return std::move(vRes);
}

在进行自动机的使用时,我们需要给定valueType,tagType以及valueType转换成key的TransformPolicy来进行构造。
以下是一个测试:

//	test.cpp
int main() {
	//转换规则
	class TransChar : public TransformPolicy {
	public:
		uint64_t transform(void *a) { return (uint64_t)*(char *)a; }
	};

	char test_words[][10] {
		"her", "she", "shy", "here", "hi", "he"
	};

	AC_Automation<char, int, TransChar> ac{};
	for (uint32_t i = 0; i < sizeof(test_words) / sizeof(test_words[0]); i++) {
		ac.addData(test_words[i], strlen(test_words[i]), i);
	}
	ac.buildFail();

	const char *query = "Oh, she is there so shy, let's go say hi.";
	auto ret = ac.Search((char *)query, strlen(query));

	std::cout << query << std::endl;
	for (auto &r : ret) {
		uint32_t pos = r.first + 1 - strlen(test_words[r.second]);
		for (uint32_t k = 0; k < pos; ++k)
			std::cout << "-";
		std::cout << test_words[r.second] << std::endl;
	}
	
	return 0;
}

案例的运行结果为:

Oh, she is there so shy, let's go say hi.
----she
-----he
------------he
------------her
------------here
--------------------shy
--------------------------------------hi