You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ntool/nstr/ac/dfa.go

730 lines
19 KiB
Go

package ac
import "unsafe"
// iDFA is a thin facade over an automaton value: every method either forwards
// to the wrapped automaton or hands it to a package-level search helper.
type iDFA struct {
	// atom is the automaton that all calls are delegated to.
	atom automaton
}

// MatchKind returns the match kind reported by the wrapped automaton.
func (d iDFA) MatchKind() *matchKind {
	kind := d.atom.MatchKind()
	return kind
}

// StartState returns the start state reported by the wrapped automaton.
func (d iDFA) StartState() stateID {
	start := d.atom.StartState()
	return start
}

// MaxPatternLen returns the maxPatternLen recorded in the automaton's iRepr.
func (d iDFA) MaxPatternLen() int {
	repr := d.atom.Repr()
	return repr.maxPatternLen
}

// PatternCount returns the patternCount recorded in the automaton's iRepr.
func (d iDFA) PatternCount() int {
	repr := d.atom.Repr()
	return repr.patternCount
}

// Prefilter returns the prefilter of the wrapped automaton; it may be nil.
func (d iDFA) Prefilter() prefilter {
	pre := d.atom.Prefilter()
	return pre
}

// UsePrefilter reports whether a prefilter is present and that prefilter's
// LooksForNonStartOfMatch returns false.
func (d iDFA) UsePrefilter() bool {
	pre := d.Prefilter()
	return pre != nil && !pre.LooksForNonStartOfMatch()
}

// OverlappingFindAt forwards to overlappingFindAt using the wrapped automaton.
func (d iDFA) OverlappingFindAt(prestate *prefilterState, haystack []byte, at int, stateId *stateID, matchIndex *int) *Match {
	found := overlappingFindAt(d.atom, prestate, haystack, at, stateId, matchIndex)
	return found
}

// EarliestFindAt forwards to earliestFindAt using the wrapped automaton.
func (d iDFA) EarliestFindAt(prestate *prefilterState, haystack []byte, at int, stateId *stateID) *Match {
	found := earliestFindAt(d.atom, prestate, haystack, at, stateId)
	return found
}

// FindAtNoState forwards to findAtNoState using the wrapped automaton.
func (d iDFA) FindAtNoState(prestate *prefilterState, haystack []byte, at int) *Match {
	found := findAtNoState(d.atom, prestate, haystack, at)
	return found
}

// LeftmostFindAtNoState forwards to leftmostFindAtNoState using the wrapped
// automaton.
func (d iDFA) LeftmostFindAtNoState(prestate *prefilterState, haystack []byte, at int) *Match {
	found := leftmostFindAtNoState(d.atom, prestate, haystack, at)
	return found
}
// iDFABuilder holds the options consulted by build when turning an NFA into
// an iDFA.
type iDFABuilder struct {
	// premultiply makes build call premultiply on the table and return one of
	// the premultiplied automaton wrappers.
	premultiply bool

	// byteClasses makes build reuse the NFA's byte classes; when false, build
	// uses singletons() instead.
	byteClasses bool
}
// build converts nfa into a table-driven iDFA.
//
// It allocates one transition row per NFA state (alphabetLen entries each),
// copies every state's match list, and fills each row from the NFA's
// transitions. A transition the NFA reports as failedStateID is resolved
// through nfaNextStateMemoized before being stored. Afterwards the match
// states are shuffled, the heap size is computed and, if requested, the table
// is premultiplied. The concrete automaton wrapper is chosen from the
// premultiply option and from whether the byte classes are singletons.
func (d *iDFABuilder) build(nfa *iNFA) iDFA {
	var bc byteClasses
	if !d.byteClasses {
		bc = singletons()
	} else {
		bc = nfa.byteClasses
	}

	alphabetLen := bc.alphabetLen()
	table := make([]stateID, alphabetLen*len(nfa.states))
	for slot := range table {
		table[slot] = failedStateID
	}
	perStateMatches := make([][]pattern, len(nfa.states))

	// Only clone the prefilter when the NFA actually has one.
	var pre prefilter
	if nfa.prefilter != nil {
		pre = nfa.prefilter.clone()
	}

	rep := iRepr{
		matchKind:     nfa.matchKind,
		anchored:      nfa.anchored,
		premultiplied: false,
		startId:       nfa.startID,
		maxPatternLen: nfa.maxPatternLen,
		patternCount:  nfa.patternCount,
		stateCount:    len(nfa.states),
		maxMatch:      failedStateID,
		heapBytes:     0,
		prefilter:     pre,
		byteClasses:   bc,
		trans:         table,
		matches:       perStateMatches,
	}

	for id := 0; id < len(nfa.states); id++ {
		sid := stateID(id)
		rep.matches[id] = append(rep.matches[id], nfa.states[id].matches...)
		fail := nfa.states[id].fail
		nfa.iterAllTransitions(&bc, sid, func(tr *next) {
			// The resolved target is written back into tr before it is stored.
			if tr.id == failedStateID {
				tr.id = nfaNextStateMemoized(nfa, &rep, sid, fail, tr.key)
			}
			rep.setNextState(sid, tr.key, tr.id)
		})
	}

	rep.shuffleMatchStates()
	rep.calculateSize()
	if d.premultiply {
		rep.premultiply()
	}

	// The singleton wrappers hold a copy of rep; the byte-class wrappers hold
	// a pointer to it.
	switch singleton := bc.isSingleton(); {
	case d.premultiply && singleton:
		return iDFA{&iPremultiplied{rep}}
	case d.premultiply:
		return iDFA{&iPremultipliedByteClass{&rep}}
	case singleton:
		return iDFA{&iStandard{rep}}
	default:
		return iDFA{&iByteClass{&rep}}
	}
}
// iByteClass is an automaton over an iRepr whose transition table is indexed
// by byte class: NextState maps the input byte through repr.byteClasses and
// scales the state id by the alphabet length.
type iByteClass struct {
	// repr is the table representation, held by pointer.
	repr *iRepr
}

// FindAtNoState forwards to findAtNoState, passing the receiver by value.
func (p iByteClass) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match {
	found := findAtNoState(p, prefilterState, bytes, i)
	return found
}

// Repr returns the underlying table representation.
func (p iByteClass) Repr() *iRepr {
	return p.repr
}

// MatchKind returns a pointer to the matchKind stored in the representation.
func (p iByteClass) MatchKind() *matchKind {
	rep := p.repr
	return &rep.matchKind
}

// Anchored returns the representation's anchored flag.
func (p iByteClass) Anchored() bool {
	rep := p.repr
	return rep.anchored
}

// Prefilter returns the representation's prefilter; it may be nil.
func (p iByteClass) Prefilter() prefilter {
	rep := p.repr
	return rep.prefilter
}

// StartState returns the representation's start state id.
func (p iByteClass) StartState() stateID {
	rep := p.repr
	return rep.startId
}

// IsValid reports whether id is below the number of states.
func (p iByteClass) IsValid(id stateID) bool {
	index := int(id)
	return index < p.repr.stateCount
}

// IsMatchState forwards to iRepr.isMatchState.
func (p iByteClass) IsMatchState(id stateID) bool {
	rep := p.repr
	return rep.isMatchState(id)
}

// IsMatchOrDeadState forwards to iRepr.isMatchStateOrDeadState.
func (p iByteClass) IsMatchOrDeadState(id stateID) bool {
	rep := p.repr
	return rep.isMatchStateOrDeadState(id)
}

// GetMatch forwards to iRepr.GetMatch.
func (p iByteClass) GetMatch(id stateID, i int, i2 int) *Match {
	rep := p.repr
	return rep.GetMatch(id, i, i2)
}

// MatchCount forwards to iRepr.MatchCount.
func (p iByteClass) MatchCount(id stateID) int {
	rep := p.repr
	return rep.MatchCount(id)
}

// NextState looks up the transition out of state id on byte b2. The byte is
// first mapped to its class; the table slot is id*alphabetLen + class.
func (p iByteClass) NextState(id stateID, b2 byte) stateID {
	rep := p.repr
	width := rep.byteClasses.alphabetLen()
	class := rep.byteClasses.bytes[b2]
	return rep.trans[int(id)*width+int(class)]
}

// NextStateNoFail is NextState, but panics if the table yields
// failedStateID.
func (p iByteClass) NextStateNoFail(id stateID, b byte) stateID {
	if next := p.NextState(id, b); next != failedStateID {
		return next
	}
	panic("automaton should never return fail_id for next state")
}

// StandardFindAt forwards to standardFindAt with the address of the receiver
// copy.
func (p iByteClass) StandardFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return standardFindAt(self, prefilterState, bytes, i, id)
}

// StandardFindAtImp forwards to standardFindAtImp with the address of the
// receiver copy.
func (p iByteClass) StandardFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return standardFindAtImp(self, prefilterState, prefilter, bytes, i, id)
}

// LeftmostFindAt forwards to leftmostFindAt with the address of the receiver
// copy.
func (p iByteClass) LeftmostFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return leftmostFindAt(self, prefilterState, bytes, i, id)
}

// LeftmostFindAtImp forwards to leftmostFindAtImp with the address of the
// receiver copy.
func (p iByteClass) LeftmostFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return leftmostFindAtImp(self, prefilterState, prefilter, bytes, i, id)
}

// LeftmostFindAtNoState forwards to leftmostFindAtNoState with the address of
// the receiver copy.
func (p iByteClass) LeftmostFindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match {
	self := &p
	return leftmostFindAtNoState(self, prefilterState, bytes, i)
}

// LeftmostFindAtNoStateImp forwards to leftmostFindAtNoStateImp with the
// address of the receiver copy.
func (p iByteClass) LeftmostFindAtNoStateImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int) *Match {
	self := &p
	return leftmostFindAtNoStateImp(self, prefilterState, prefilter, bytes, i)
}

// OverlappingFindAt forwards to overlappingFindAt with the address of the
// receiver copy.
func (p iByteClass) OverlappingFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match {
	self := &p
	return overlappingFindAt(self, prefilterState, bytes, i, id, i2)
}

// EarliestFindAt forwards to earliestFindAt with the address of the receiver
// copy.
func (p iByteClass) EarliestFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return earliestFindAt(self, prefilterState, bytes, i, id)
}

// FindAt forwards to findAt with the address of the receiver copy.
func (p iByteClass) FindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return findAt(self, prefilterState, bytes, i, id)
}
// iPremultipliedByteClass is an automaton over an iRepr that uses byte
// classes and whose state ids are handled as premultiplied values: methods
// that need a per-state index divide the id by the alphabet length.
type iPremultipliedByteClass struct {
	// repr is the table representation, held by pointer.
	repr *iRepr
}

// FindAtNoState forwards to findAtNoState, passing the receiver by value.
func (p iPremultipliedByteClass) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match {
	found := findAtNoState(p, prefilterState, bytes, i)
	return found
}

// Repr returns the underlying table representation.
func (p iPremultipliedByteClass) Repr() *iRepr {
	return p.repr
}

// MatchKind returns a pointer to the matchKind stored in the representation.
func (p iPremultipliedByteClass) MatchKind() *matchKind {
	rep := p.repr
	return &rep.matchKind
}

// Anchored returns the representation's anchored flag.
func (p iPremultipliedByteClass) Anchored() bool {
	rep := p.repr
	return rep.anchored
}

// Prefilter returns the representation's prefilter; it may be nil.
func (p iPremultipliedByteClass) Prefilter() prefilter {
	rep := p.repr
	return rep.prefilter
}

// StartState returns the representation's start state id.
func (p iPremultipliedByteClass) StartState() stateID {
	rep := p.repr
	return rep.startId
}

// IsValid reports whether id, divided by the alphabet length, is below the
// number of states.
func (p iPremultipliedByteClass) IsValid(id stateID) bool {
	state := int(id) / p.repr.alphabetLen()
	return state < p.repr.stateCount
}

// IsMatchState forwards to iRepr.isMatchState.
func (p iPremultipliedByteClass) IsMatchState(id stateID) bool {
	rep := p.repr
	return rep.isMatchState(id)
}

// IsMatchOrDeadState forwards to iRepr.isMatchStateOrDeadState.
func (p iPremultipliedByteClass) IsMatchOrDeadState(id stateID) bool {
	rep := p.repr
	return rep.isMatchStateOrDeadState(id)
}
// GetMatch returns the matchIndex-th match recorded for the state with
// premultiplied id, using end as the reported end offset.
//
// The per-state index is id divided by the alphabet length. The method
// returns nil when id is above maxMatch, when that index has no entry in the
// match table, or when matchIndex is outside the state's match list. The
// bounds checks replace an index-out-of-range panic, in line with the nil
// result iRepr.GetMatch aims for.
func (p iPremultipliedByteClass) GetMatch(id stateID, matchIndex int, end int) *Match {
	if id > p.repr.maxMatch {
		return nil
	}
	state := int(id) / p.repr.alphabetLen()
	if state < 0 || state >= len(p.repr.matches) {
		return nil
	}
	stateMatches := p.repr.matches[state]
	if matchIndex < 0 || matchIndex >= len(stateMatches) {
		return nil
	}
	m := stateMatches[matchIndex]
	return &Match{
		pattern: m.PatternID,
		len:     m.PatternLength,
		end:     end,
	}
}
// MatchCount returns the number of matches recorded for the state whose
// premultiplied id is id (the id is divided by the alphabet length to find
// the entry).
func (p iPremultipliedByteClass) MatchCount(id stateID) int {
	state := int(id) / p.repr.alphabetLen()
	stateMatches := p.repr.matches[state]
	return len(stateMatches)
}

// NextState looks up the transition out of premultiplied state id on byte b.
// The byte is mapped to its class and added directly to id to form the table
// slot; no multiplication is performed here.
func (p iPremultipliedByteClass) NextState(id stateID, b byte) stateID {
	rep := p.repr
	class := rep.byteClasses.bytes[b]
	return rep.trans[int(id)+int(class)]
}

// NextStateNoFail is NextState, but panics if the table yields
// failedStateID.
func (p iPremultipliedByteClass) NextStateNoFail(id stateID, b byte) stateID {
	// TODO this leaks garbage
	if n := p.NextState(id, b); n != failedStateID {
		return n
	}
	panic("automaton should never return fail_id for next state")
}

// StandardFindAt forwards to standardFindAt with the address of the receiver
// copy.
func (p iPremultipliedByteClass) StandardFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return standardFindAt(self, prefilterState, bytes, i, id)
}

// StandardFindAtImp forwards to standardFindAtImp with the address of the
// receiver copy.
func (p iPremultipliedByteClass) StandardFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return standardFindAtImp(self, prefilterState, prefilter, bytes, i, id)
}

// LeftmostFindAt forwards to leftmostFindAt with the address of the receiver
// copy.
func (p iPremultipliedByteClass) LeftmostFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return leftmostFindAt(self, prefilterState, bytes, i, id)
}

// LeftmostFindAtImp forwards to leftmostFindAtImp with the address of the
// receiver copy.
func (p iPremultipliedByteClass) LeftmostFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return leftmostFindAtImp(self, prefilterState, prefilter, bytes, i, id)
}

// LeftmostFindAtNoState forwards to leftmostFindAtNoState with the address of
// the receiver copy.
func (p iPremultipliedByteClass) LeftmostFindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match {
	self := &p
	return leftmostFindAtNoState(self, prefilterState, bytes, i)
}

// LeftmostFindAtNoStateImp forwards to leftmostFindAtNoStateImp with the
// address of the receiver copy.
func (p iPremultipliedByteClass) LeftmostFindAtNoStateImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int) *Match {
	self := &p
	return leftmostFindAtNoStateImp(self, prefilterState, prefilter, bytes, i)
}

// OverlappingFindAt forwards to overlappingFindAt with the address of the
// receiver copy.
func (p iPremultipliedByteClass) OverlappingFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match {
	self := &p
	return overlappingFindAt(self, prefilterState, bytes, i, id, i2)
}

// EarliestFindAt forwards to earliestFindAt with the address of the receiver
// copy.
func (p iPremultipliedByteClass) EarliestFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return earliestFindAt(self, prefilterState, bytes, i, id)
}

// FindAt forwards to findAt with the address of the receiver copy.
func (p iPremultipliedByteClass) FindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return findAt(self, prefilterState, bytes, i, id)
}
// iPremultiplied is an automaton over an iRepr with a fixed row width of 256
// and state ids handled as premultiplied values: methods that need a
// per-state index divide the id by 256.
type iPremultiplied struct {
	// repr is the table representation, stored by value.
	repr iRepr
}

// FindAtNoState forwards to findAtNoState, passing the receiver by value.
func (p iPremultiplied) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match {
	found := findAtNoState(p, prefilterState, bytes, i)
	return found
}

// Repr returns a pointer to the representation.
//
// NOTE(review): the receiver is a value, so the pointer refers to the repr
// field of this call's copy of p, not to the caller's struct.
func (p iPremultiplied) Repr() *iRepr {
	rep := &p.repr
	return rep
}

// MatchKind returns a pointer to the matchKind field of the receiver copy's
// representation.
func (p iPremultiplied) MatchKind() *matchKind {
	kind := &p.repr.matchKind
	return kind
}

// Anchored returns the representation's anchored flag.
func (p iPremultiplied) Anchored() bool {
	anchored := p.repr.anchored
	return anchored
}

// Prefilter returns the representation's prefilter; it may be nil.
func (p iPremultiplied) Prefilter() prefilter {
	pre := p.repr.prefilter
	return pre
}

// StartState returns the representation's start state id.
func (p iPremultiplied) StartState() stateID {
	start := p.repr.startId
	return start
}

// IsValid reports whether id, divided by 256, is below the number of states.
func (p iPremultiplied) IsValid(id stateID) bool {
	state := int(id) / 256
	return state < p.repr.stateCount
}

// IsMatchState forwards to iRepr.isMatchState.
func (p iPremultiplied) IsMatchState(id stateID) bool {
	isMatch := p.repr.isMatchState(id)
	return isMatch
}

// IsMatchOrDeadState forwards to iRepr.isMatchStateOrDeadState.
func (p iPremultiplied) IsMatchOrDeadState(id stateID) bool {
	matchOrDead := p.repr.isMatchStateOrDeadState(id)
	return matchOrDead
}
// GetMatch returns the matchIndex-th match recorded for the state with
// premultiplied id, using end as the reported end offset.
//
// The per-state index is id divided by 256. The method returns nil when id is
// above maxMatch, when that index has no entry in the match table, or when
// matchIndex is outside the state's match list. The bounds checks replace an
// index-out-of-range panic, in line with the nil result iRepr.GetMatch aims
// for.
func (p iPremultiplied) GetMatch(id stateID, matchIndex int, end int) *Match {
	if id > p.repr.maxMatch {
		return nil
	}
	state := int(id) / 256
	if state < 0 || state >= len(p.repr.matches) {
		return nil
	}
	stateMatches := p.repr.matches[state]
	if matchIndex < 0 || matchIndex >= len(stateMatches) {
		return nil
	}
	m := stateMatches[matchIndex]
	return &Match{
		pattern: m.PatternID,
		len:     m.PatternLength,
		end:     end,
	}
}
// MatchCount returns the number of matches recorded for the state whose
// premultiplied id is id (the id is divided by 256 to find the entry).
func (p iPremultiplied) MatchCount(id stateID) int {
	state := int(id) / 256
	stateMatches := p.repr.matches[state]
	return len(stateMatches)
}

// NextState looks up the transition out of premultiplied state id on byte b;
// the table slot is simply id + b.
func (p iPremultiplied) NextState(id stateID, b byte) stateID {
	slot := int(id) + int(b)
	return p.repr.trans[slot]
}

// NextStateNoFail is NextState, but panics if the table yields
// failedStateID.
func (p iPremultiplied) NextStateNoFail(id stateID, b byte) stateID {
	if next := p.NextState(id, b); next != failedStateID {
		return next
	}
	panic("automaton should never return fail_id for next state")
}

// StandardFindAt forwards to standardFindAt with the address of the receiver
// copy.
func (p iPremultiplied) StandardFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return standardFindAt(self, prefilterState, bytes, i, id)
}

// StandardFindAtImp forwards to standardFindAtImp with the address of the
// receiver copy.
func (p iPremultiplied) StandardFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return standardFindAtImp(self, prefilterState, prefilter, bytes, i, id)
}

// LeftmostFindAt forwards to leftmostFindAt with the address of the receiver
// copy.
func (p iPremultiplied) LeftmostFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return leftmostFindAt(self, prefilterState, bytes, i, id)
}

// LeftmostFindAtImp forwards to leftmostFindAtImp with the address of the
// receiver copy.
func (p iPremultiplied) LeftmostFindAtImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return leftmostFindAtImp(self, prefilterState, prefilter, bytes, i, id)
}

// LeftmostFindAtNoState forwards to leftmostFindAtNoState with the address of
// the receiver copy.
func (p iPremultiplied) LeftmostFindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match {
	self := &p
	return leftmostFindAtNoState(self, prefilterState, bytes, i)
}

// LeftmostFindAtNoStateImp forwards to leftmostFindAtNoStateImp with the
// address of the receiver copy.
func (p iPremultiplied) LeftmostFindAtNoStateImp(prefilterState *prefilterState, prefilter prefilter, bytes []byte, i int) *Match {
	self := &p
	return leftmostFindAtNoStateImp(self, prefilterState, prefilter, bytes, i)
}

// OverlappingFindAt forwards to overlappingFindAt with the address of the
// receiver copy.
func (p iPremultiplied) OverlappingFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match {
	self := &p
	return overlappingFindAt(self, prefilterState, bytes, i, id, i2)
}

// EarliestFindAt forwards to earliestFindAt with the address of the receiver
// copy.
func (p iPremultiplied) EarliestFindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return earliestFindAt(self, prefilterState, bytes, i, id)
}

// FindAt forwards to findAt with the address of the receiver copy.
func (p iPremultiplied) FindAt(prefilterState *prefilterState, bytes []byte, i int, id *stateID) *Match {
	self := &p
	return findAt(self, prefilterState, bytes, i, id)
}
// nfaNextStateMemoized resolves the transition on input starting from NFA
// state current, while the DFA row for state populating is being filled.
//
// It walks the NFA's fail links from current. As soon as the walk reaches a
// state whose id is below populating, the answer is read from the DFA table
// instead (rows below populating are the ones build has already filled).
// Otherwise the first NFA transition that is not failedStateID is returned.
func nfaNextStateMemoized(nfa *iNFA, dfa *iRepr, populating stateID, current stateID, input byte) stateID {
	for current >= populating {
		state := &nfa.states[current]
		if target := state.nextState(input); target != failedStateID {
			return target
		}
		current = state.fail
	}
	return dfa.nextState(current, input)
}
// newDFABuilder returns a builder with both premultiply and byteClasses
// switched on.
func newDFABuilder() *iDFABuilder {
	builder := new(iDFABuilder)
	builder.premultiply = true
	builder.byteClasses = true
	return builder
}
// iStandard is an automaton over an iRepr with a fixed row width of 256 and
// plain state ids: NextState indexes the table at id*256 + input byte.
type iStandard struct {
	// repr is the table representation, stored by value.
	repr iRepr
}

// FindAtNoState forwards to findAtNoState.
func (p *iStandard) FindAtNoState(prefilterState *prefilterState, bytes []byte, i int) *Match {
	found := findAtNoState(p, prefilterState, bytes, i)
	return found
}

// Repr returns a pointer to the embedded representation.
func (p *iStandard) Repr() *iRepr {
	rep := &p.repr
	return rep
}

// MatchKind returns a pointer to the matchKind stored in the representation.
func (p *iStandard) MatchKind() *matchKind {
	kind := &p.repr.matchKind
	return kind
}

// Anchored returns the representation's anchored flag.
func (p *iStandard) Anchored() bool {
	anchored := p.repr.anchored
	return anchored
}

// Prefilter returns the representation's prefilter; it may be nil.
func (p *iStandard) Prefilter() prefilter {
	pre := p.repr.prefilter
	return pre
}

// StartState returns the representation's start state id.
func (p *iStandard) StartState() stateID {
	start := p.repr.startId
	return start
}

// IsValid reports whether id is below the number of states.
func (p *iStandard) IsValid(id stateID) bool {
	index := int(id)
	return index < p.repr.stateCount
}

// IsMatchState forwards to iRepr.isMatchState.
func (p *iStandard) IsMatchState(id stateID) bool {
	isMatch := p.repr.isMatchState(id)
	return isMatch
}

// IsMatchOrDeadState forwards to iRepr.isMatchStateOrDeadState.
func (p *iStandard) IsMatchOrDeadState(id stateID) bool {
	matchOrDead := p.repr.isMatchStateOrDeadState(id)
	return matchOrDead
}

// GetMatch forwards to iRepr.GetMatch.
func (p *iStandard) GetMatch(id stateID, matchIndex int, end int) *Match {
	m := p.repr.GetMatch(id, matchIndex, end)
	return m
}

// MatchCount forwards to iRepr.MatchCount.
func (p *iStandard) MatchCount(id stateID) int {
	count := p.repr.MatchCount(id)
	return count
}

// NextState looks up the transition out of state current on byte input; the
// table slot is current*256 + input.
func (p *iStandard) NextState(current stateID, input byte) stateID {
	row := int(current) * 256
	return p.repr.trans[row+int(input)]
}

// NextStateNoFail is NextState, but panics if the table yields
// failedStateID.
func (p *iStandard) NextStateNoFail(id stateID, b byte) stateID {
	if next := p.NextState(id, b); next != failedStateID {
		return next
	}
	panic("automaton should never return fail_id for next state")
}

// StandardFindAt forwards to standardFindAt.
func (p *iStandard) StandardFindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match {
	found := standardFindAt(p, state, bytes, i, id)
	return found
}

// StandardFindAtImp forwards to standardFindAtImp.
func (p *iStandard) StandardFindAtImp(state *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	found := standardFindAtImp(p, state, prefilter, bytes, i, id)
	return found
}

// LeftmostFindAt forwards to leftmostFindAt.
func (p *iStandard) LeftmostFindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match {
	found := leftmostFindAt(p, state, bytes, i, id)
	return found
}

// LeftmostFindAtImp forwards to leftmostFindAtImp.
func (p *iStandard) LeftmostFindAtImp(state *prefilterState, prefilter prefilter, bytes []byte, i int, id *stateID) *Match {
	found := leftmostFindAtImp(p, state, prefilter, bytes, i, id)
	return found
}

// LeftmostFindAtNoState forwards to leftmostFindAtNoState.
func (p *iStandard) LeftmostFindAtNoState(state *prefilterState, bytes []byte, i int) *Match {
	found := leftmostFindAtNoState(p, state, bytes, i)
	return found
}

// LeftmostFindAtNoStateImp forwards to leftmostFindAtNoStateImp.
func (p *iStandard) LeftmostFindAtNoStateImp(state *prefilterState, prefilter prefilter, bytes []byte, i int) *Match {
	found := leftmostFindAtNoStateImp(p, state, prefilter, bytes, i)
	return found
}

// OverlappingFindAt forwards to overlappingFindAt.
func (p *iStandard) OverlappingFindAt(state *prefilterState, bytes []byte, i int, id *stateID, i2 *int) *Match {
	found := overlappingFindAt(p, state, bytes, i, id, i2)
	return found
}

// EarliestFindAt forwards to earliestFindAt.
func (p *iStandard) EarliestFindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match {
	found := earliestFindAt(p, state, bytes, i, id)
	return found
}

// FindAt forwards to findAt.
func (p *iStandard) FindAt(state *prefilterState, bytes []byte, i int, id *stateID) *Match {
	found := findAt(p, state, bytes, i, id)
	return found
}
// iRepr is the table representation shared by the automaton wrappers in this
// file: a flat transition table plus per-state match lists and the metadata
// copied from the NFA at build time.
type iRepr struct {
	// matchKind and anchored are copied from the NFA by build.
	matchKind matchKind
	anchored  bool

	// premultiplied is set by premultiply; swapStates and shuffleMatchStates
	// panic once it is true.
	premultiplied bool

	// startId is the start state id; premultiply scales it by the alphabet
	// length.
	startId stateID

	// maxPatternLen and patternCount are copied from the NFA by build.
	maxPatternLen int
	patternCount  int

	// stateCount is the number of states (rows) in the table.
	stateCount int

	// maxMatch is the upper bound used by isMatchState: ids at or below it
	// and above deadStateID are match states.
	maxMatch stateID

	// heapBytes is the size estimate written by calculateSize.
	heapBytes int

	// prefilter is the optional prefilter; nil when absent.
	prefilter prefilter

	// byteClasses maps input bytes to table columns and defines the row
	// width (alphabetLen).
	byteClasses byteClasses

	// trans is the flat transition table, one row of alphabetLen entries per
	// state.
	trans []stateID

	// matches holds the match list of each state.
	matches [][]pattern
}
// premultiply rewrites the table so that stored state ids are already scaled
// by the alphabet length.
//
// It is a no-op when the table is already premultiplied or has at most one
// state. Rows 0 and 1 are left untouched, and within the other rows any
// entry equal to deadStateID is kept as is. startId and maxMatch are scaled
// as well.
func (r *iRepr) premultiply() {
	if r.premultiplied || r.stateCount <= 1 {
		return
	}

	width := r.alphabetLen()
	for state := 2; state < r.stateCount; state++ {
		base := state * width
		row := r.trans[base : base+width]
		for col, target := range row {
			if target != deadStateID {
				row[col] = stateID(int(target) * width)
			}
		}
	}

	r.startId = stateID(int(r.startId) * width)
	r.maxMatch = stateID(int(r.maxMatch) * width)
	r.premultiplied = true
}
// setNextState stores to as the transition out of state from on byte b. The
// byte is mapped to its class first; the slot is from*alphabetLen + class.
func (r *iRepr) setNextState(from stateID, b byte, to stateID) {
	width := r.alphabetLen()
	class := r.byteClasses.bytes[b]
	slot := int(from)*width + int(class)
	r.trans[slot] = to
}
// alphabetLen returns the row width of the transition table, as reported by
// the byte classes.
func (r *iRepr) alphabetLen() int {
	width := r.byteClasses.alphabetLen()
	return width
}
// nextState reads the transition out of state from on byte b. The byte is
// mapped to its class first; the slot is from*alphabetLen + class.
func (r *iRepr) nextState(from stateID, b byte) stateID {
	width := r.alphabetLen()
	class := r.byteClasses.bytes[b]
	slot := int(from)*width + int(class)
	return r.trans[slot]
}
// isMatchState reports whether id lies strictly above deadStateID and at or
// below maxMatch.
func (r *iRepr) isMatchState(id stateID) bool {
	if id > r.maxMatch {
		return false
	}
	return id > deadStateID
}
// isMatchStateOrDeadState reports whether id is at or below maxMatch.
func (r *iRepr) isMatchStateOrDeadState(id stateID) bool {
	withinMatchRange := id <= r.maxMatch
	return withinMatchRange
}
// GetMatch returns the matchIndex-th match recorded for state id, using end
// as the reported end offset.
//
// It returns nil when id is above maxMatch, when id has no entry in the match
// table, or when matchIndex is outside the state's match list. Both range
// checks use >= against the slice length: the previous > comparisons let an
// index equal to the length through, which then panicked on the slice access
// instead of returning nil.
func (r *iRepr) GetMatch(id stateID, matchIndex int, end int) *Match {
	if id > r.maxMatch {
		return nil
	}
	state := int(id)
	if state < 0 || state >= len(r.matches) {
		return nil
	}
	stateMatches := r.matches[state]
	if matchIndex < 0 || matchIndex >= len(stateMatches) {
		return nil
	}
	m := stateMatches[matchIndex]
	return &Match{
		pattern: m.PatternID,
		len:     m.PatternLength,
		end:     end,
	}
}
// MatchCount returns the number of matches recorded for state id.
func (r *iRepr) MatchCount(id stateID) int {
	stateMatches := r.matches[id]
	return len(stateMatches)
}
// swapStates exchanges the transition rows and the match lists of states id1
// and id2. It panics if the table has already been premultiplied.
//
// Only the two rows themselves are exchanged; transitions elsewhere in the
// table that point at id1 or id2 are not rewritten here.
func (r *iRepr) swapStates(id1 stateID, id2 stateID) {
	if r.premultiplied {
		panic("cannot shuffle match states of premultiplied iDFA")
	}

	first, second := int(id1), int(id2)
	baseFirst := first * r.alphabetLen()
	baseSecond := second * r.alphabetLen()
	for col := 0; col < r.alphabetLen(); col++ {
		a, b := baseFirst+col, baseSecond+col
		r.trans[a], r.trans[b] = r.trans[b], r.trans[a]
	}
	r.matches[first], r.matches[second] = r.matches[second], r.matches[first]
}
// calculateSize stores an estimate of the representation's heap usage in
// heapBytes.
//
// The unit is the size of a stateID: one unit per transition entry, three
// units per match list, and two units per recorded match. If a prefilter is
// present, its HeapBytes value is added.
func (r *iRepr) calculateSize() {
	unit := int(unsafe.Sizeof(stateID(1)))

	total := len(r.trans) * unit
	total += len(r.matches) * (unit * 3)
	for state := range r.matches {
		total += len(r.matches[state]) * (unit * 2)
	}
	if r.prefilter != nil {
		total += r.prefilter.HeapBytes()
	}

	r.heapBytes = total
}
// shuffleMatchStates renumbers states so that the states with a non-empty
// match list are moved towards the low end of the id range, starting at
// startId, and then records the highest such id in maxMatch. This is what
// lets isMatchState decide with a simple comparison against maxMatch.
//
// It panics if the table has already been premultiplied and does nothing when
// there is at most one state. States with an index below startId are never
// moved, because the scan starts at startId.
func (r *iRepr) shuffleMatchStates() {
if r.premultiplied {
panic("cannot shuffle match states of premultiplied iDFA")
}
if r.stateCount <= 1 {
return
}
// Find the first state at or after startId that has no matches; everything
// before it is already in place.
firstNonMatch := int(r.startId)
for firstNonMatch < r.stateCount && len(r.matches[firstNonMatch]) > 0 {
firstNonMatch += 1
}
// swaps[i] is the new id of state i if it was moved; failedStateID marks a
// state that was not moved.
// NOTE(review): this sentinel assumes no state is ever moved to the id equal
// to failedStateID - confirm against the definition of failedStateID.
swaps := make([]stateID, r.stateCount)
for i := range swaps {
swaps[i] = failedStateID
}
// Walk down from the last state. Every match state found above firstNonMatch
// is exchanged with the non-match state at firstNonMatch, and firstNonMatch
// is then advanced past any match states that follow it.
cur := r.stateCount - 1
for cur > firstNonMatch {
if len(r.matches[cur]) > 0 {
r.swapStates(stateID(cur), stateID(firstNonMatch))
swaps[cur] = stateID(firstNonMatch)
swaps[firstNonMatch] = stateID(cur)
firstNonMatch += 1
for firstNonMatch < cur && len(r.matches[firstNonMatch]) > 0 {
firstNonMatch += 1
}
}
cur -= 1
}
// swapStates only exchanges rows, so every transition target in the table
// still uses the old numbering; rewrite the ones that refer to a moved state.
// Each target is used as an index into swaps, so targets are assumed to be
// below stateCount.
for id := 0; id < r.stateCount; id++ {
alphabetLen := r.alphabetLen()
offset := id * alphabetLen
slice := r.trans[offset : offset+alphabetLen]
for i := range slice {
if swaps[slice[i]] != failedStateID {
slice[i] = swaps[slice[i]]
}
}
}
// The start state itself may have been moved.
if swaps[r.startId] != failedStateID {
r.startId = swaps[r.startId]
}
// The last match state sits immediately before the first non-match state.
r.maxMatch = stateID(firstNonMatch - 1)
}
// pattern is one entry of a state's match list. Match values are built from
// it: PatternID becomes Match.pattern and PatternLength becomes Match.len.
type pattern struct {
	// PatternID identifies the pattern.
	PatternID int

	// PatternLength is the length recorded for the pattern.
	PatternLength int
}