Reformat Zeek in Spicy style

This largely copies over Spicy's `.clang-format` configuration file. The
one place where we deviate is header include order since Zeek depends on
headers being included in a certain order.
This commit is contained in:
Benjamin Bannier 2023-10-10 21:13:34 +02:00
parent 7b8e7ed72c
commit f5a76c1aed
786 changed files with 131714 additions and 153609 deletions

View file

@ -1,74 +1,66 @@
# Clang-format configuration for Zeek. This configuration requires # Copyright (c) 2020-2023 by the Zeek Project. See LICENSE for details.
# at least clang-format 12.0.1 to format correctly.
Language: Cpp
Standard: c++17
BreakBeforeBraces: Whitesmiths
# BraceWrapping:
# AfterCaseLabel: true
# AfterClass: false
# AfterControlStatement: Always
# AfterEnum: false
# AfterFunction: true
# AfterNamespace: false
# AfterStruct: false
# AfterUnion: false
# AfterExternBlock: false
# BeforeCatch: true
# BeforeElse: true
# BeforeWhile: false
# IndentBraces: true
# SplitEmptyFunction: false
# SplitEmptyRecord: false
# SplitEmptyNamespace: false
---
Language: Cpp
AccessModifierOffset: -4 AccessModifierOffset: -4
AlignAfterOpenBracket: Align AlignAfterOpenBracket: Align
AlignTrailingComments: false AlignConsecutiveAssignments: false
AllowShortBlocksOnASingleLine: Empty AlignConsecutiveDeclarations: false
AllowShortEnumsOnASingleLine: true AlignEscapedNewlines: Right
AllowShortFunctionsOnASingleLine: Inline AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: true
AllowShortIfStatementsOnASingleLine: false AllowShortIfStatementsOnASingleLine: false
AllowShortLambdasOnASingleLine: Empty
AllowShortLoopsOnASingleLine: false AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: true BinPackArguments: true
BinPackParameters: true BinPackParameters: true
BreakConstructorInitializers: BeforeColon BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: true
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon BreakInheritanceList: BeforeColon
ColumnLimit: 100 BreakBeforeTernaryOperators: false
ConstructorInitializerAllOnOneLineOrOnePerLine: false BreakConstructorInitializersBeforeComma: false
FixNamespaceComments: false BreakConstructorInitializers: BeforeColon
IndentCaseLabels: true BreakAfterJavaFieldAnnotations: false
IndentCaseBlocks: false BreakStringLiterals: true
IndentExternBlock: NoIndent ColumnLimit: 120
IndentPPDirectives: None CommentPragmas: 'NOLINT'
IndentWidth: 4 CompactNamespaces: false
NamespaceIndentation: None ConstructorInitializerAllOnOneLineOrOnePerLine: true
PointerAlignment: Left ConstructorInitializerIndentWidth: 4
SpaceAfterCStyleCast: false ContinuationIndentWidth: 4
SpaceAfterLogicalNot: true Cpp11BracedListStyle: true
SpaceBeforeAssignmentOperators: true DerivePointerAlignment: false
SpaceBeforeCpp11BracedList: false DisableFormat: false
SpaceBeforeCtorInitializerColon: true ExperimentalAutoDetectBinPacking: false
SpaceBeforeInheritanceColon: true FixNamespaceComments: true
SpaceBeforeParens: ControlStatements ForEachMacros:
SpaceBeforeRangeBasedForLoopColon: true - foreach
SpaceInEmptyBlock: true - Q_FOREACH
SpaceInEmptyParentheses: false - BOOST_FOREACH
SpacesInAngles: false
SpacesInConditionalStatement: true
SpacesInContainerLiterals: false
SpacesInParentheses: false
TabWidth: 4
UseTab: AlignWithSpaces
# Setting this to a high number causes clang-format to prefer breaking somewhere else
# over breaking after the assignment operator in a line that's over the column limit
PenaltyBreakAssignment: 100
IncludeBlocks: Regroup IncludeBlocks: Regroup
# Include categories go like this: # Include categories go like this:
@ -98,3 +90,55 @@ IncludeCategories:
Priority: 4 Priority: 4
- Regex: '.*' - Regex: '.*'
Priority: 5 Priority: 5
IncludeIsMainRegex: '$'
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: '^BEGIN_'
MacroBlockEnd: '^END_'
MaxEmptyLinesToKeep: 2
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 500
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 1000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: false
SpaceAfterLogicalNot: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SpacesInConditionalStatement: true
Standard: Cpp11
StatementMacros:
- STANDARD_OPERATOR_1
TabWidth: 4
UseTab: Never
...

@ -1 +1 @@
Subproject commit 7b8eff527f60ec58eff3242253bdc1f5f1fccbef Subproject commit d26c81c0a2982ef81339beebff455c23713fb526

2
cmake

@ -1 +1 @@
Subproject commit 98799bb51aabb282e7dd6372aea7dbcf909469ac Subproject commit f7b4fbe4892594034d3d9ca639c0ffa6a99fcbe5

2
doc

@ -1 +1 @@
Subproject commit 22fe25d980131abdfadb4bdb9390aee347e77023 Subproject commit 01d78f885e6aac4e853a0b5da559b4c849fee743

View file

@ -15,435 +15,392 @@
#include "zeek/net_util.h" #include "zeek/net_util.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS] = {nullptr}; AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS] = {nullptr};
static uint32_t rand32() static uint32_t rand32() {
{ return ((util::detail::random_number() & 0xffff) << 16) | (util::detail::random_number() & 0xffff);
return ((util::detail::random_number() & 0xffff) << 16) | }
(util::detail::random_number() & 0xffff);
}
// From tcpdpriv. // From tcpdpriv.
static int bi_ffs(uint32_t value) static int bi_ffs(uint32_t value) {
{ int add = 0;
int add = 0; static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1};
static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1};
if ( (value & 0xFFFF0000) == 0 ) if ( (value & 0xFFFF0000) == 0 ) {
{ if ( value == 0 )
if ( value == 0 ) // Zero input ==> zero output.
// Zero input ==> zero output. return 0;
return 0;
add += 16; add += 16;
} }
else else
value >>= 16; value >>= 16;
if ( (value & 0xFF00) == 0 ) if ( (value & 0xFF00) == 0 )
add += 8; add += 8;
else else
value >>= 8; value >>= 8;
if ( (value & 0xF0) == 0 ) if ( (value & 0xF0) == 0 )
add += 4; add += 4;
else else
value >>= 4; value >>= 4;
return add + bvals[value & 0xf]; return add + bvals[value & 0xf];
} }
#define first_n_bit_mask(n) (~(0xFFFFFFFFU >> n)) #define first_n_bit_mask(n) (~(0xFFFFFFFFU >> n))
ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) {
{ std::map<ipaddr32_t, ipaddr32_t>::iterator p = mapping.find(addr);
std::map<ipaddr32_t, ipaddr32_t>::iterator p = mapping.find(addr); if ( p != mapping.end() )
if ( p != mapping.end() ) return p->second;
return p->second; else {
else ipaddr32_t new_addr = anonymize(addr);
{ mapping[addr] = new_addr;
ipaddr32_t new_addr = anonymize(addr);
mapping[addr] = new_addr;
return new_addr; return new_addr;
} }
} }
// Keep the specified prefix unchanged. // Keep the specified prefix unchanged.
bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) {
{ reporter->InternalError("prefix preserving is not supported for the anonymizer");
reporter->InternalError("prefix preserving is not supported for the anonymizer"); return false;
return false; }
}
bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) {
{ switch ( addr_to_class(ntohl(input)) ) {
switch ( addr_to_class(ntohl(input)) ) case 'A': return PreservePrefix(input, 8);
{ case 'B': return PreservePrefix(input, 16);
case 'A': case 'C': return PreservePrefix(input, 24);
return PreservePrefix(input, 8); default: return false;
case 'B': }
return PreservePrefix(input, 16); }
case 'C':
return PreservePrefix(input, 24);
default:
return false;
}
}
ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) {
{ ++seq;
++seq; return htonl(seq);
return htonl(seq); }
}
ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) {
{ uint8_t digest[16];
uint8_t digest[16]; ipaddr32_t output = 0;
ipaddr32_t output = 0;
util::detail::hmac_md5(sizeof(input), (u_char*)(&input), digest); util::detail::hmac_md5(sizeof(input), (u_char*)(&input), digest);
for ( int i = 0; i < 4; ++i ) for ( int i = 0; i < 4; ++i )
output = (output << 8) | digest[i]; output = (output << 8) | digest[i];
return output; return output;
} }
// This code is from "On the Design and Performance of Prefix-Preserving // This code is from "On the Design and Performance of Prefix-Preserving
// IP Traffic Trace Anonymization", by Xu et al (IMW 2001) // IP Traffic Trace Anonymization", by Xu et al (IMW 2001)
// //
// http://www.imconf.net/imw-2001/proceedings.html // http://www.imconf.net/imw-2001/proceedings.html
ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) {
{ uint8_t digest[16];
uint8_t digest[16]; ipaddr32_t prefix_mask = 0xffffffff;
ipaddr32_t prefix_mask = 0xffffffff; input = ntohl(input);
input = ntohl(input); ipaddr32_t output = input;
ipaddr32_t output = input;
for ( int i = 0; i < 32; ++i ) for ( int i = 0; i < 32; ++i ) {
{ // PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 .
// PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 . prefix.len = htonl(i + 1);
prefix.len = htonl(i + 1); prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i)));
prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i)));
// HK(PAD(x_0 ... x_{i-1})). // HK(PAD(x_0 ... x_{i-1})).
util::detail::hmac_md5(sizeof(prefix), (u_char*)&prefix, digest); util::detail::hmac_md5(sizeof(prefix), (u_char*)&prefix, digest);
// f_{i-1} = LSB(HK(PAD(x_0 ... x_{i-1}))). // f_{i-1} = LSB(HK(PAD(x_0 ... x_{i-1}))).
ipaddr32_t bit_mask = (digest[0] & 1) << (31 - i); ipaddr32_t bit_mask = (digest[0] & 1) << (31 - i);
// x_i' = x_i ^ f_{i-1}. // x_i' = x_i ^ f_{i-1}.
output ^= bit_mask; output ^= bit_mask;
} }
return htonl(output); return htonl(output);
} }
AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() {
{ for ( auto& b : blocks )
for ( auto& b : blocks ) delete[] b;
delete[] b; }
}
void AnonymizeIPAddr_A50::init() void AnonymizeIPAddr_A50::init() {
{ root = next_free_node = nullptr;
root = next_free_node = nullptr;
// Prepare special nodes for 0.0.0.0 and 255.255.255.255. // Prepare special nodes for 0.0.0.0 and 255.255.255.255.
memset(&special_nodes[0], 0, sizeof(special_nodes)); memset(&special_nodes[0], 0, sizeof(special_nodes));
special_nodes[0].input = special_nodes[0].output = 0; special_nodes[0].input = special_nodes[0].output = 0;
special_nodes[1].input = special_nodes[1].output = 0xFFFFFFFF; special_nodes[1].input = special_nodes[1].output = 0xFFFFFFFF;
method = 0; method = 0;
before_anonymization = 1; before_anonymization = 1;
new_mapping = 0; new_mapping = 0;
} }
bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) {
{ DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits);
DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits);
if ( ! before_anonymization ) if ( ! before_anonymization ) {
{ reporter->Error("prefix preservation specified after anonymization begun");
reporter->Error("prefix preservation specified after anonymization begun"); return false;
return false; }
}
input = ntohl(input); input = ntohl(input);
// Sanitize input. // Sanitize input.
input = input & first_n_bit_mask(num_bits); input = input & first_n_bit_mask(num_bits);
Node* n = find_node(input); Node* n = find_node(input);
// Preserve the first num_bits bits of addr. // Preserve the first num_bits bits of addr.
if ( num_bits == 32 ) if ( num_bits == 32 )
n->output = input; n->output = input;
else if ( num_bits > 0 ) else if ( num_bits > 0 ) {
{ assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU);
assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU); uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits);
uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits); uint32_t prefix_mask = ~suffix_mask;
uint32_t prefix_mask = ~suffix_mask; n->output = (input & prefix_mask) | (rand32() & suffix_mask);
n->output = (input & prefix_mask) | (rand32() & suffix_mask); }
}
return true; return true;
} }
ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) {
{ before_anonymization = 0;
before_anonymization = 0; new_mapping = 0;
new_mapping = 0;
if ( Node* n = find_node(ntohl(a)) ) if ( Node* n = find_node(ntohl(a)) ) {
{ ipaddr32_t output = htonl(n->output);
ipaddr32_t output = htonl(n->output); return output;
return output; }
} else
else return 0;
return 0; }
}
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() {
{ assert(! next_free_node);
assert(! next_free_node);
int block_size = 1024; int block_size = 1024;
Node* block = new Node[block_size]; Node* block = new Node[block_size];
if ( ! block ) if ( ! block )
reporter->InternalError("out of memory!"); reporter->InternalError("out of memory!");
blocks.push_back(block); blocks.push_back(block);
for ( int i = 1; i < block_size - 1; ++i ) for ( int i = 1; i < block_size - 1; ++i )
block[i].child[0] = &block[i + 1]; block[i].child[0] = &block[i + 1];
block[block_size - 1].child[0] = nullptr; block[block_size - 1].child[0] = nullptr;
next_free_node = &block[1]; next_free_node = &block[1];
return &block[0]; return &block[0];
} }
inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() {
{ new_mapping = 1;
new_mapping = 1;
if ( next_free_node ) if ( next_free_node ) {
{ Node* n = next_free_node;
Node* n = next_free_node; next_free_node = n->child[0];
next_free_node = n->child[0]; return n;
return n; }
} else
else return new_node_block();
return new_node_block(); }
}
inline void AnonymizeIPAddr_A50::free_node(Node* n) inline void AnonymizeIPAddr_A50::free_node(Node* n) {
{ n->child[0] = next_free_node;
n->child[0] = next_free_node; next_free_node = n;
next_free_node = n; }
}
ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const {
{ // -A50 anonymization
// -A50 anonymization if ( swivel == 32 )
if ( swivel == 32 ) return old_output ^ 1;
return old_output ^ 1; else {
else // Bits up to swivel are unchanged; bit swivel is flipped.
{ ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel);
// Bits up to swivel are unchanged; bit swivel is flipped.
ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel);
// Remainder of bits are random. // Remainder of bits are random.
return known_part | ((rand32() & 0x7FFFFFFF) >> swivel); return known_part | ((rand32() & 0x7FFFFFFF) >> swivel);
} }
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) {
{ if ( a == 0 || a == 0xFFFFFFFFU )
if ( a == 0 || a == 0xFFFFFFFFU ) reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree");
reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree");
// Become a peer. // Become a peer.
// Algorithm: create two nodes, the two peers. Leave orig node as // Algorithm: create two nodes, the two peers. Leave orig node as
// the parent of the two new ones. // the parent of the two new ones.
Node* down[2]; Node* down[2];
down[0] = new_node(); down[0] = new_node();
if ( ! down[0] ) if ( ! down[0] )
return nullptr; return nullptr;
down[1] = new_node(); down[1] = new_node();
if ( ! down[1] ) if ( ! down[1] ) {
{ free_node(down[0]);
free_node(down[0]); return nullptr;
return nullptr; }
}
// swivel is first bit 'a' and 'old->input' differ. // swivel is first bit 'a' and 'old->input' differ.
int swivel = bi_ffs(a ^ n->input); int swivel = bi_ffs(a ^ n->input);
// bitvalue is the value of that bit of 'a'. // bitvalue is the value of that bit of 'a'.
int bitvalue = (a >> (32 - swivel)) & 1; int bitvalue = (a >> (32 - swivel)) & 1;
down[bitvalue]->input = a; down[bitvalue]->input = a;
down[bitvalue]->output = make_output(n->output, swivel); down[bitvalue]->output = make_output(n->output, swivel);
down[bitvalue]->child[0] = down[bitvalue]->child[1] = nullptr; down[bitvalue]->child[0] = down[bitvalue]->child[1] = nullptr;
*down[1 - bitvalue] = *n; // copy orig node down one level *down[1 - bitvalue] = *n; // copy orig node down one level
n->input = down[1]->input; // NB: 1s to the right (0s to the left) n->input = down[1]->input; // NB: 1s to the right (0s to the left)
n->output = down[1]->output; n->output = down[1]->output;
n->child[0] = down[0]; // point to children n->child[0] = down[0]; // point to children
n->child[1] = down[1]; n->child[1] = down[1];
return down[bitvalue]; return down[bitvalue];
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) {
{ // Watch out for special IP addresses, which never make it
// Watch out for special IP addresses, which never make it // into the tree.
// into the tree. if ( a == 0 || a == 0xFFFFFFFFU )
if ( a == 0 || a == 0xFFFFFFFFU ) return &special_nodes[a & 1];
return &special_nodes[a & 1];
if ( ! root ) if ( ! root ) {
{ root = new_node();
root = new_node(); root->input = a;
root->input = a; root->output = rand32();
root->output = rand32(); root->child[0] = root->child[1] = nullptr;
root->child[0] = root->child[1] = nullptr;
return root; return root;
} }
// Straight from tcpdpriv. // Straight from tcpdpriv.
Node* n = root; Node* n = root;
while ( n ) while ( n ) {
{ if ( n->input == a )
if ( n->input == a ) return n;
return n;
if ( ! n->child[0] ) if ( ! n->child[0] )
n = make_peer(a, n); n = make_peer(a, n);
else else {
{ // swivel is the first bit in which the two children
// swivel is the first bit in which the two children // differ.
// differ. int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input);
int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input);
if ( bi_ffs(a ^ n->input) < swivel ) if ( bi_ffs(a ^ n->input) < swivel )
// Input differs earlier. // Input differs earlier.
n = make_peer(a, n); n = make_peer(a, n);
else if ( a & (1 << (32 - swivel)) ) else if ( a & (1 << (32 - swivel)) )
n = n->child[1]; n = n->child[1];
else else
n = n->child[0]; n = n->child[0];
} }
} }
reporter->InternalError("out of memory!"); reporter->InternalError("out of memory!");
return nullptr; return nullptr;
} }
static TableValPtr anon_preserve_orig_addr; static TableValPtr anon_preserve_orig_addr;
static TableValPtr anon_preserve_resp_addr; static TableValPtr anon_preserve_resp_addr;
static TableValPtr anon_preserve_other_addr; static TableValPtr anon_preserve_other_addr;
void init_ip_addr_anonymizers() void init_ip_addr_anonymizers() {
{ ip_anonymizer[KEEP_ORIG_ADDR] = nullptr;
ip_anonymizer[KEEP_ORIG_ADDR] = nullptr; ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq();
ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq(); ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5();
ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5(); ip_anonymizer[PREFIX_PRESERVING_A50] = new AnonymizeIPAddr_A50();
ip_anonymizer[PREFIX_PRESERVING_A50] = new AnonymizeIPAddr_A50(); ip_anonymizer[PREFIX_PRESERVING_MD5] = new AnonymizeIPAddr_PrefixMD5();
ip_anonymizer[PREFIX_PRESERVING_MD5] = new AnonymizeIPAddr_PrefixMD5();
auto id = global_scope()->Find("preserve_orig_addr"); auto id = global_scope()->Find("preserve_orig_addr");
if ( id ) if ( id )
anon_preserve_orig_addr = cast_intrusive<TableVal>(id->GetVal()); anon_preserve_orig_addr = cast_intrusive<TableVal>(id->GetVal());
id = global_scope()->Find("preserve_resp_addr"); id = global_scope()->Find("preserve_resp_addr");
if ( id ) if ( id )
anon_preserve_resp_addr = cast_intrusive<TableVal>(id->GetVal()); anon_preserve_resp_addr = cast_intrusive<TableVal>(id->GetVal());
id = global_scope()->Find("preserve_other_addr"); id = global_scope()->Find("preserve_other_addr");
if ( id ) if ( id )
anon_preserve_other_addr = cast_intrusive<TableVal>(id->GetVal()); anon_preserve_other_addr = cast_intrusive<TableVal>(id->GetVal());
} }
ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) {
{ TableVal* preserve_addr = nullptr;
TableVal* preserve_addr = nullptr; auto addr = make_intrusive<AddrVal>(ip);
auto addr = make_intrusive<AddrVal>(ip);
int method = -1; int method = -1;
switch ( cl ) switch ( cl ) {
{ case ORIG_ADDR: // client address
case ORIG_ADDR: // client address preserve_addr = anon_preserve_orig_addr.get();
preserve_addr = anon_preserve_orig_addr.get(); method = orig_addr_anonymization;
method = orig_addr_anonymization; break;
break;
case RESP_ADDR: // server address case RESP_ADDR: // server address
preserve_addr = anon_preserve_resp_addr.get(); preserve_addr = anon_preserve_resp_addr.get();
method = resp_addr_anonymization; method = resp_addr_anonymization;
break; break;
default: default:
preserve_addr = anon_preserve_other_addr.get(); preserve_addr = anon_preserve_other_addr.get();
method = other_addr_anonymization; method = other_addr_anonymization;
break; break;
} }
ipaddr32_t new_ip = 0; ipaddr32_t new_ip = 0;
if ( preserve_addr && preserve_addr->FindOrDefault(addr) ) if ( preserve_addr && preserve_addr->FindOrDefault(addr) )
new_ip = ip; new_ip = ip;
else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) {
{ if ( method == KEEP_ORIG_ADDR )
if ( method == KEEP_ORIG_ADDR ) new_ip = ip;
new_ip = ip;
else if ( ! ip_anonymizer[method] ) else if ( ! ip_anonymizer[method] )
reporter->InternalError("IP anonymizer not initialized"); reporter->InternalError("IP anonymizer not initialized");
else else
new_ip = ip_anonymizer[method]->Anonymize(ip); new_ip = ip_anonymizer[method]->Anonymize(ip);
} }
else else
reporter->InternalError("invalid IP anonymization method"); reporter->InternalError("invalid IP anonymization method");
#ifdef LOG_ANONYMIZATION_MAPPING #ifdef LOG_ANONYMIZATION_MAPPING
log_anonymization_mapping(ip, new_ip); log_anonymization_mapping(ip, new_ip);
#endif #endif
return new_ip; return new_ip;
} }
#ifdef LOG_ANONYMIZATION_MAPPING #ifdef LOG_ANONYMIZATION_MAPPING
void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) {
{ if ( anonymization_mapping )
if ( anonymization_mapping ) event_mgr.Enqueue(anonymization_mapping, make_intrusive<AddrVal>(input), make_intrusive<AddrVal>(output));
event_mgr.Enqueue(anonymization_mapping, make_intrusive<AddrVal>(input), }
make_intrusive<AddrVal>(output));
}
#endif #endif
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -14,121 +14,111 @@
#include <map> #include <map>
#include <vector> #include <vector>
namespace zeek::detail namespace zeek::detail {
{
// TODO: Anon.h may not be the right place to put these functions ... // TODO: Anon.h may not be the right place to put these functions ...
enum ip_addr_anonymization_class_t enum ip_addr_anonymization_class_t {
{ ORIG_ADDR, // client address
ORIG_ADDR, // client address RESP_ADDR, // server address
RESP_ADDR, // server address OTHER_ADDR,
OTHER_ADDR, NUM_ADDR_ANONYMIZATION_CLASSES,
NUM_ADDR_ANONYMIZATION_CLASSES, };
};
enum ip_addr_anonymization_method_t enum ip_addr_anonymization_method_t {
{ KEEP_ORIG_ADDR,
KEEP_ORIG_ADDR, SEQUENTIALLY_NUMBERED,
SEQUENTIALLY_NUMBERED, RANDOM_MD5,
RANDOM_MD5, PREFIX_PRESERVING_A50,
PREFIX_PRESERVING_A50, PREFIX_PRESERVING_MD5,
PREFIX_PRESERVING_MD5, NUM_ADDR_ANONYMIZATION_METHODS,
NUM_ADDR_ANONYMIZATION_METHODS, };
};
using ipaddr32_t = uint32_t; using ipaddr32_t = uint32_t;
// NOTE: all addresses in parameters of *public* functions are in // NOTE: all addresses in parameters of *public* functions are in
// network order. // network order.
class AnonymizeIPAddr class AnonymizeIPAddr {
{
public: public:
virtual ~AnonymizeIPAddr() = default; virtual ~AnonymizeIPAddr() = default;
ipaddr32_t Anonymize(ipaddr32_t addr); ipaddr32_t Anonymize(ipaddr32_t addr);
virtual bool PreservePrefix(ipaddr32_t input, int num_bits); virtual bool PreservePrefix(ipaddr32_t input, int num_bits);
virtual ipaddr32_t anonymize(ipaddr32_t addr) = 0; virtual ipaddr32_t anonymize(ipaddr32_t addr) = 0;
bool PreserveNet(ipaddr32_t input); bool PreserveNet(ipaddr32_t input);
protected: protected:
std::map<ipaddr32_t, ipaddr32_t> mapping; std::map<ipaddr32_t, ipaddr32_t> mapping;
}; };
class AnonymizeIPAddr_Seq : public AnonymizeIPAddr class AnonymizeIPAddr_Seq : public AnonymizeIPAddr {
{
public: public:
AnonymizeIPAddr_Seq() { seq = 1; } AnonymizeIPAddr_Seq() { seq = 1; }
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
protected: protected:
ipaddr32_t seq; ipaddr32_t seq;
}; };
class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr {
{
public: public:
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
}; };
class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr {
{
public: public:
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
protected: protected:
struct anon_prefix struct anon_prefix {
{ int len;
int len; ipaddr32_t prefix;
ipaddr32_t prefix; } prefix;
} prefix; };
};
class AnonymizeIPAddr_A50 : public AnonymizeIPAddr class AnonymizeIPAddr_A50 : public AnonymizeIPAddr {
{
public: public:
AnonymizeIPAddr_A50() { init(); } AnonymizeIPAddr_A50() { init(); }
~AnonymizeIPAddr_A50() override; ~AnonymizeIPAddr_A50() override;
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
bool PreservePrefix(ipaddr32_t input, int num_bits) override; bool PreservePrefix(ipaddr32_t input, int num_bits) override;
protected: protected:
struct Node struct Node {
{ ipaddr32_t input;
ipaddr32_t input; ipaddr32_t output;
ipaddr32_t output; Node* child[2];
Node* child[2]; };
};
int method; int method;
int before_anonymization; int before_anonymization;
int new_mapping; int new_mapping;
// The root of prefix preserving mapping tree. // The root of prefix preserving mapping tree.
Node* root; Node* root;
// A node pool for new_node. // A node pool for new_node.
Node* next_free_node; Node* next_free_node;
std::vector<Node*> blocks; std::vector<Node*> blocks;
// for 0.0.0.0 and 255.255.255.255. // for 0.0.0.0 and 255.255.255.255.
Node special_nodes[2]; Node special_nodes[2];
void init(); void init();
Node* new_node(); Node* new_node();
Node* new_node_block(); Node* new_node_block();
void free_node(Node*); void free_node(Node*);
ipaddr32_t make_output(ipaddr32_t, int) const; ipaddr32_t make_output(ipaddr32_t, int) const;
Node* make_peer(ipaddr32_t, Node*); Node* make_peer(ipaddr32_t, Node*);
Node* find_node(ipaddr32_t); Node* find_node(ipaddr32_t);
}; };
// The global IP anonymizers. // The global IP anonymizers.
extern AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS]; extern AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS];
@ -139,4 +129,4 @@ ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl);
#define LOG_ANONYMIZATION_MAPPING #define LOG_ANONYMIZATION_MAPPING
void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output); void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output);
} // namespace zeek::detail } // namespace zeek::detail

File diff suppressed because it is too large Load diff

View file

@ -14,137 +14,131 @@
// modify expressions or supply metadata on types, and the kind that // modify expressions or supply metadata on types, and the kind that
// are extra metadata on every variable instance. // are extra metadata on every variable instance.
namespace zeek namespace zeek {
{
class Type; class Type;
using TypePtr = IntrusivePtr<Type>; using TypePtr = IntrusivePtr<Type>;
namespace detail namespace detail {
{
class Expr; class Expr;
using ExprPtr = IntrusivePtr<Expr>; using ExprPtr = IntrusivePtr<Expr>;
enum AttrTag enum AttrTag {
{ ATTR_OPTIONAL,
ATTR_OPTIONAL, ATTR_DEFAULT,
ATTR_DEFAULT, ATTR_DEFAULT_INSERT, // insert default value on failed lookups
ATTR_DEFAULT_INSERT, // insert default value on failed lookups ATTR_REDEF,
ATTR_REDEF, ATTR_ADD_FUNC,
ATTR_ADD_FUNC, ATTR_DEL_FUNC,
ATTR_DEL_FUNC, ATTR_EXPIRE_FUNC,
ATTR_EXPIRE_FUNC, ATTR_EXPIRE_READ,
ATTR_EXPIRE_READ, ATTR_EXPIRE_WRITE,
ATTR_EXPIRE_WRITE, ATTR_EXPIRE_CREATE,
ATTR_EXPIRE_CREATE, ATTR_RAW_OUTPUT,
ATTR_RAW_OUTPUT, ATTR_PRIORITY,
ATTR_PRIORITY, ATTR_GROUP,
ATTR_GROUP, ATTR_LOG,
ATTR_LOG, ATTR_ERROR_HANDLER,
ATTR_ERROR_HANDLER, ATTR_TYPE_COLUMN, // for input framework
ATTR_TYPE_COLUMN, // for input framework ATTR_TRACKED, // hidden attribute, tracked by NotifierRegistry
ATTR_TRACKED, // hidden attribute, tracked by NotifierRegistry ATTR_ON_CHANGE, // for table change tracking
ATTR_ON_CHANGE, // for table change tracking ATTR_BROKER_STORE, // for Broker store backed tables
ATTR_BROKER_STORE, // for Broker store backed tables ATTR_BROKER_STORE_ALLOW_COMPLEX, // for Broker store backed tables
ATTR_BROKER_STORE_ALLOW_COMPLEX, // for Broker store backed tables ATTR_BACKEND, // for Broker store backed tables
ATTR_BACKEND, // for Broker store backed tables ATTR_DEPRECATED,
ATTR_DEPRECATED, ATTR_IS_ASSIGNED, // to suppress usage warnings
ATTR_IS_ASSIGNED, // to suppress usage warnings ATTR_IS_USED, // to suppress usage warnings
ATTR_IS_USED, // to suppress usage warnings ATTR_ORDERED, // used to store tables in ordered mode
ATTR_ORDERED, // used to store tables in ordered mode NUM_ATTRS // this item should always be last
NUM_ATTRS // this item should always be last };
};
class Attr; class Attr;
using AttrPtr = IntrusivePtr<Attr>; using AttrPtr = IntrusivePtr<Attr>;
class Attributes; class Attributes;
using AttributesPtr = IntrusivePtr<Attributes>; using AttributesPtr = IntrusivePtr<Attributes>;
class Attr final : public Obj class Attr final : public Obj {
{
public: public:
static inline const AttrPtr nil; static inline const AttrPtr nil;
Attr(AttrTag t, ExprPtr e); Attr(AttrTag t, ExprPtr e);
explicit Attr(AttrTag t); explicit Attr(AttrTag t);
~Attr() override = default; ~Attr() override = default;
AttrTag Tag() const { return tag; } AttrTag Tag() const { return tag; }
const ExprPtr& GetExpr() const { return expr; } const ExprPtr& GetExpr() const { return expr; }
void SetAttrExpr(ExprPtr e); void SetAttrExpr(ExprPtr e);
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
void DescribeReST(ODesc* d, bool shorten = false) const; void DescribeReST(ODesc* d, bool shorten = false) const;
/** /**
* Returns the deprecation string associated with a &deprecated attribute * Returns the deprecation string associated with a &deprecated attribute
* or an empty string if this is not such an attribute. * or an empty string if this is not such an attribute.
*/ */
std::string DeprecationMessage() const; std::string DeprecationMessage() const;
bool operator==(const Attr& other) const bool operator==(const Attr& other) const {
{ if ( tag != other.tag )
if ( tag != other.tag ) return false;
return false;
if ( expr || other.expr ) if ( expr || other.expr )
// Too hard to check for equivalency, since one // Too hard to check for equivalency, since one
// might be expressed/compiled differently than // might be expressed/compiled differently than
// the other, so assume they're compatible, as // the other, so assume they're compatible, as
// long as both are present. // long as both are present.
return expr && other.expr; return expr && other.expr;
return true; return true;
} }
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; detail::TraversalCode Traverse(detail::TraversalCallback* cb) const;
protected: protected:
void AddTag(ODesc* d) const; void AddTag(ODesc* d) const;
AttrTag tag; AttrTag tag;
ExprPtr expr; ExprPtr expr;
}; };
// Manages a collection of attributes. // Manages a collection of attributes.
class Attributes final : public Obj class Attributes final : public Obj {
{
public: public:
Attributes(std::vector<AttrPtr> a, TypePtr t, bool in_record, bool is_global); Attributes(std::vector<AttrPtr> a, TypePtr t, bool in_record, bool is_global);
Attributes(TypePtr t, bool in_record, bool is_global); Attributes(TypePtr t, bool in_record, bool is_global);
~Attributes() override = default; ~Attributes() override = default;
void AddAttr(AttrPtr a, bool is_redef = false); void AddAttr(AttrPtr a, bool is_redef = false);
void AddAttrs(const AttributesPtr& a, bool is_redef = false); void AddAttrs(const AttributesPtr& a, bool is_redef = false);
const AttrPtr& Find(AttrTag t) const; const AttrPtr& Find(AttrTag t) const;
void RemoveAttr(AttrTag t); void RemoveAttr(AttrTag t);
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
void DescribeReST(ODesc* d, bool shorten = false) const; void DescribeReST(ODesc* d, bool shorten = false) const;
const std::vector<AttrPtr>& GetAttrs() const { return attrs; } const std::vector<AttrPtr>& GetAttrs() const { return attrs; }
bool operator==(const Attributes& other) const; bool operator==(const Attributes& other) const;
detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; detail::TraversalCode Traverse(detail::TraversalCallback* cb) const;
protected: protected:
void CheckAttr(Attr* attr); void CheckAttr(Attr* attr);
TypePtr type; TypePtr type;
std::vector<AttrPtr> attrs; std::vector<AttrPtr> attrs;
bool in_record; bool in_record;
bool global_var; bool global_var;
}; };
// Checks whether default attribute "a" is compatible with the given type. // Checks whether default attribute "a" is compatible with the given type.
// "global_var" specifies whether the attribute is being associated with // "global_var" specifies whether the attribute is being associated with
@ -154,8 +148,7 @@ protected:
// Returns true on compatibility (which might include modifying "a"), false // Returns true on compatibility (which might include modifying "a"), false
// on an error. If an error message hasn't been directly generated, then // on an error. If an error message hasn't been directly generated, then
// it will be returned in err_msg. // it will be returned in err_msg.
extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, std::string& err_msg);
std::string& err_msg);
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -8,277 +8,251 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
int Base64Converter::default_base64_table[256]; int Base64Converter::default_base64_table[256];
const std::string Base64Converter::default_alphabet = const std::string Base64Converter::default_alphabet =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) {
{ int blen;
int blen; char* buf;
char* buf;
if ( ! pbuf ) if ( ! pbuf )
reporter->InternalError("nil pointer to encoding result buffer"); reporter->InternalError("nil pointer to encoding result buffer");
if ( *pbuf && (*pblen % 4 != 0) ) if ( *pbuf && (*pblen % 4 != 0) )
reporter->InternalError("Base64 encode buffer not a multiple of 4"); reporter->InternalError("Base64 encode buffer not a multiple of 4");
if ( *pbuf ) if ( *pbuf ) {
{ buf = *pbuf;
buf = *pbuf; blen = *pblen;
blen = *pblen; }
} else {
else blen = (int)(4 * ceil((double)len / 3));
{ *pbuf = buf = new char[blen];
blen = (int)(4 * ceil((double)len / 3)); *pblen = blen;
*pbuf = buf = new char[blen]; }
*pblen = blen;
}
for ( int i = 0, j = 0; (i < len) && (j < blen); ) for ( int i = 0, j = 0; (i < len) && (j < blen); ) {
{ uint32_t bit32 = data[i++] << 16;
uint32_t bit32 = data[i++] << 16; bit32 += (i++ < len ? data[i - 1] : 0) << 8;
bit32 += (i++ < len ? data[i - 1] : 0) << 8; bit32 += i++ < len ? data[i - 1] : 0;
bit32 += i++ < len ? data[i - 1] : 0;
buf[j++] = alphabet[(bit32 >> 18) & 0x3f]; buf[j++] = alphabet[(bit32 >> 18) & 0x3f];
buf[j++] = alphabet[(bit32 >> 12) & 0x3f]; buf[j++] = alphabet[(bit32 >> 12) & 0x3f];
buf[j++] = (i == (len + 2)) ? '=' : alphabet[(bit32 >> 6) & 0x3f]; buf[j++] = (i == (len + 2)) ? '=' : alphabet[(bit32 >> 6) & 0x3f];
buf[j++] = (i >= (len + 1)) ? '=' : alphabet[bit32 & 0x3f]; buf[j++] = (i >= (len + 1)) ? '=' : alphabet[bit32 & 0x3f];
} }
} }
int* Base64Converter::InitBase64Table(const std::string& alphabet) int* Base64Converter::InitBase64Table(const std::string& alphabet) {
{ assert(alphabet.size() == 64);
assert(alphabet.size() == 64);
static bool default_table_initialized = false; static bool default_table_initialized = false;
if ( alphabet == default_alphabet && default_table_initialized ) if ( alphabet == default_alphabet && default_table_initialized )
return default_base64_table; return default_base64_table;
int* base64_table = nullptr; int* base64_table = nullptr;
if ( alphabet == default_alphabet ) if ( alphabet == default_alphabet ) {
{ base64_table = default_base64_table;
base64_table = default_base64_table; default_table_initialized = true;
default_table_initialized = true; }
} else
else base64_table = new int[256];
base64_table = new int[256];
int i; int i;
for ( i = 0; i < 256; ++i ) for ( i = 0; i < 256; ++i )
base64_table[i] = -1; base64_table[i] = -1;
for ( i = 0; i < 26; ++i ) for ( i = 0; i < 26; ++i ) {
{ base64_table[int(alphabet[0 + i])] = i;
base64_table[int(alphabet[0 + i])] = i; base64_table[int(alphabet[26 + i])] = i + 26;
base64_table[int(alphabet[26 + i])] = i + 26; }
}
for ( i = 0; i < 10; ++i ) for ( i = 0; i < 10; ++i )
base64_table[int(alphabet[52 + i])] = i + 52; base64_table[int(alphabet[52 + i])] = i + 52;
// Casts to avoid compiler warnings. // Casts to avoid compiler warnings.
base64_table[int(alphabet[62])] = 62; base64_table[int(alphabet[62])] = 62;
base64_table[int(alphabet[63])] = 63; base64_table[int(alphabet[63])] = 63;
base64_table[int('=')] = 0; base64_table[int('=')] = 0;
return base64_table; return base64_table;
} }
Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) {
{ if ( arg_alphabet.size() > 0 ) {
if ( arg_alphabet.size() > 0 ) assert(arg_alphabet.size() == 64);
{ alphabet = arg_alphabet;
assert(arg_alphabet.size() == 64); }
alphabet = arg_alphabet; else {
} alphabet = default_alphabet;
else }
{
alphabet = default_alphabet;
}
base64_table = nullptr; base64_table = nullptr;
base64_group_next = 0; base64_group_next = 0;
base64_padding = base64_after_padding = 0; base64_padding = base64_after_padding = 0;
errored = 0; errored = 0;
conn = arg_conn; conn = arg_conn;
} }
Base64Converter::~Base64Converter() Base64Converter::~Base64Converter() {
{ if ( base64_table != default_base64_table )
if ( base64_table != default_base64_table ) delete[] base64_table;
delete[] base64_table; }
}
int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) {
{ int blen;
int blen; char* buf;
char* buf;
// Initialization of table on first_time call of Decode. // Initialization of table on first_time call of Decode.
if ( ! base64_table ) if ( ! base64_table )
base64_table = InitBase64Table(alphabet); base64_table = InitBase64Table(alphabet);
if ( ! pbuf ) if ( ! pbuf )
reporter->InternalError("nil pointer to decoding result buffer"); reporter->InternalError("nil pointer to decoding result buffer");
if ( *pbuf ) if ( *pbuf ) {
{ buf = *pbuf;
buf = *pbuf; blen = *pblen;
blen = *pblen; }
} else {
else // Estimate the maximal number of 3-byte groups needed,
{ // plus 1 byte for the optional ending NUL.
// Estimate the maximal number of 3-byte groups needed, blen = int((len + base64_group_next + 3) / 4) * 3 + 1;
// plus 1 byte for the optional ending NUL. *pbuf = buf = new char[blen];
blen = int((len + base64_group_next + 3) / 4) * 3 + 1; }
*pbuf = buf = new char[blen];
}
int dlen = 0; int dlen = 0;
while ( true ) while ( true ) {
{ if ( base64_group_next == 4 ) {
if ( base64_group_next == 4 ) // For every group of 4 6-bit numbers,
{ // write the decoded 3 bytes to the buffer.
// For every group of 4 6-bit numbers, if ( base64_after_padding ) {
// write the decoded 3 bytes to the buffer. if ( ++errored == 1 )
if ( base64_after_padding ) IllegalEncoding("extra base64 groups after '=' padding are ignored");
{ base64_group_next = 0;
if ( ++errored == 1 ) continue;
IllegalEncoding("extra base64 groups after '=' padding are ignored"); }
base64_group_next = 0;
continue;
}
int num_octets = 3 - base64_padding; int num_octets = 3 - base64_padding;
if ( buf + num_octets > *pbuf + blen ) if ( buf + num_octets > *pbuf + blen )
break; break;
uint32_t bit32 = ((base64_group[0] & 0x3f) << 18) | ((base64_group[1] & 0x3f) << 12) | uint32_t bit32 = ((base64_group[0] & 0x3f) << 18) | ((base64_group[1] & 0x3f) << 12) |
((base64_group[2] & 0x3f) << 6) | ((base64_group[3] & 0x3f)); ((base64_group[2] & 0x3f) << 6) | ((base64_group[3] & 0x3f));
if ( --num_octets >= 0 ) if ( --num_octets >= 0 )
*buf++ = char((bit32 >> 16) & 0xff); *buf++ = char((bit32 >> 16) & 0xff);
if ( --num_octets >= 0 ) if ( --num_octets >= 0 )
*buf++ = char((bit32 >> 8) & 0xff); *buf++ = char((bit32 >> 8) & 0xff);
if ( --num_octets >= 0 ) if ( --num_octets >= 0 )
*buf++ = char((bit32)&0xff); *buf++ = char((bit32)&0xff);
if ( base64_padding > 0 ) if ( base64_padding > 0 )
base64_after_padding = 1; base64_after_padding = 1;
base64_group_next = 0; base64_group_next = 0;
base64_padding = 0; base64_padding = 0;
} }
if ( dlen >= len ) if ( dlen >= len )
break; break;
unsigned char c = (unsigned char)data[dlen]; unsigned char c = (unsigned char)data[dlen];
if ( c == '=' ) if ( c == '=' )
++base64_padding; ++base64_padding;
int k = base64_table[c]; int k = base64_table[c];
if ( k >= 0 ) if ( k >= 0 )
base64_group[base64_group_next++] = k; base64_group[base64_group_next++] = k;
else else {
{ if ( ++errored == 1 )
if ( ++errored == 1 ) IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c));
IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c)); }
}
++dlen; ++dlen;
} }
*pblen = buf - *pbuf; *pblen = buf - *pbuf;
return dlen; return dlen;
} }
int Base64Converter::Done(int* pblen, char** pbuf) int Base64Converter::Done(int* pblen, char** pbuf) {
{ const char* padding = "===";
const char* padding = "===";
if ( base64_group_next != 0 ) if ( base64_group_next != 0 ) {
{ if ( base64_group_next < 4 )
if ( base64_group_next < 4 ) IllegalEncoding(
IllegalEncoding(util::fmt("incomplete base64 group, padding with %d bits of 0", util::fmt("incomplete base64 group, padding with %d bits of 0", (4 - base64_group_next) * 6));
(4 - base64_group_next) * 6)); Decode(4 - base64_group_next, padding, pblen, pbuf);
Decode(4 - base64_group_next, padding, pblen, pbuf); return -1;
return -1; }
}
if ( pblen ) if ( pblen )
*pblen = 0; *pblen = 0;
return 0; return 0;
} }
void Base64Converter::IllegalEncoding(const char* msg) void Base64Converter::IllegalEncoding(const char* msg) {
{ // strncpy(error_msg, msg, sizeof(error_msg));
// strncpy(error_msg, msg, sizeof(error_msg)); if ( conn )
if ( conn ) conn->Weird("base64_illegal_encoding", msg);
conn->Weird("base64_illegal_encoding", msg); else
else reporter->Error("%s", msg);
reporter->Error("%s", msg); }
}
String* decode_base64(const String* s, const String* a, Connection* conn) String* decode_base64(const String* s, const String* a, Connection* conn) {
{ if ( a && a->Len() != 0 && a->Len() != 64 ) {
if ( a && a->Len() != 0 && a->Len() != 64 ) reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString());
{ return nullptr;
reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString()); }
return nullptr;
}
int buf_len = int((s->Len() + 3) / 4) * 3 + 1; int buf_len = int((s->Len() + 3) / 4) * 3 + 1;
int rlen2, rlen = buf_len; int rlen2, rlen = buf_len;
char *rbuf2, *rbuf = new char[rlen]; char *rbuf2, *rbuf = new char[rlen];
Base64Converter dec(conn, a ? a->CheckString() : ""); Base64Converter dec(conn, a ? a->CheckString() : "");
dec.Decode(s->Len(), (const char*)s->Bytes(), &rlen, &rbuf); dec.Decode(s->Len(), (const char*)s->Bytes(), &rlen, &rbuf);
if ( dec.Errored() ) if ( dec.Errored() )
goto err; goto err;
rlen2 = buf_len - rlen; rlen2 = buf_len - rlen;
rbuf2 = rbuf + rlen; rbuf2 = rbuf + rlen;
// Done() returns -1 if there isn't enough padding, but we just ignore // Done() returns -1 if there isn't enough padding, but we just ignore
// it. // it.
dec.Done(&rlen2, &rbuf2); dec.Done(&rlen2, &rbuf2);
rlen += rlen2; rlen += rlen2;
rbuf[rlen] = '\0'; rbuf[rlen] = '\0';
return new String(true, (u_char*)rbuf, rlen); return new String(true, (u_char*)rbuf, rlen);
err: err:
delete[] rbuf; delete[] rbuf;
return nullptr; return nullptr;
} }
String* encode_base64(const String* s, const String* a, Connection* conn) String* encode_base64(const String* s, const String* a, Connection* conn) {
{ if ( a && a->Len() != 0 && a->Len() != 64 ) {
if ( a && a->Len() != 0 && a->Len() != 64 ) reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString());
{ return nullptr;
reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString()); }
return nullptr;
}
char* outbuf = nullptr; char* outbuf = nullptr;
int outlen = 0; int outlen = 0;
Base64Converter enc(conn, a ? a->CheckString() : ""); Base64Converter enc(conn, a ? a->CheckString() : "");
enc.Encode(s->Len(), (const unsigned char*)s->Bytes(), &outlen, &outbuf); enc.Encode(s->Len(), (const unsigned char*)s->Bytes(), &outlen, &outbuf);
return new String(true, (u_char*)outbuf, outlen); return new String(true, (u_char*)outbuf, outlen);
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,69 +4,66 @@
#include <string> #include <string>
namespace zeek namespace zeek {
{
class String; class String;
class Connection; class Connection;
namespace detail namespace detail {
{
// Maybe we should have a base class for generic decoders? // Maybe we should have a base class for generic decoders?
class Base64Converter class Base64Converter {
{
public: public:
// <conn> is used for error reporting. If it is set to zero (as, // <conn> is used for error reporting. If it is set to zero (as,
// e.g., done by the built-in functions decode_base64() and // e.g., done by the built-in functions decode_base64() and
// encode_base64()), encoding-errors will go to Reporter instead of // encode_base64()), encoding-errors will go to Reporter instead of
// Weird. Usage errors go to Reporter in any case. Empty alphabet // Weird. Usage errors go to Reporter in any case. Empty alphabet
// indicates the default base64 alphabet. // indicates the default base64 alphabet.
explicit Base64Converter(Connection* conn, const std::string& alphabet = ""); explicit Base64Converter(Connection* conn, const std::string& alphabet = "");
~Base64Converter(); ~Base64Converter();
// A note on Decode(): // A note on Decode():
// //
// The input is specified by <len> and <data> and the output // The input is specified by <len> and <data> and the output
// buffer by <blen> and <buf>. If *buf is nil, a buffer of // buffer by <blen> and <buf>. If *buf is nil, a buffer of
// an appropriate size will be new'd and *buf will point // an appropriate size will be new'd and *buf will point
// to the buffer on return. *blen holds the length of // to the buffer on return. *blen holds the length of
// decoded data on return. The function returns the number of // decoded data on return. The function returns the number of
// input bytes processed, since the decoding will stop when there // input bytes processed, since the decoding will stop when there
// is not enough output buffer space. // is not enough output buffer space.
int Decode(int len, const char* data, int* blen, char** buf); int Decode(int len, const char* data, int* blen, char** buf);
void Encode(int len, const unsigned char* data, int* blen, char** buf); void Encode(int len, const unsigned char* data, int* blen, char** buf);
int Done(int* pblen, char** pbuf); int Done(int* pblen, char** pbuf);
bool HasData() const { return base64_group_next != 0; } bool HasData() const { return base64_group_next != 0; }
// True if an error has occurred. // True if an error has occurred.
int Errored() const { return errored; } int Errored() const { return errored; }
const char* ErrorMsg() const { return error_msg; } const char* ErrorMsg() const { return error_msg; }
void IllegalEncoding(const char* msg); void IllegalEncoding(const char* msg);
protected: protected:
char error_msg[256]; char error_msg[256];
protected: protected:
static const std::string default_alphabet; static const std::string default_alphabet;
std::string alphabet; std::string alphabet;
static int* InitBase64Table(const std::string& alphabet); static int* InitBase64Table(const std::string& alphabet);
static int default_base64_table[256]; static int default_base64_table[256];
char base64_group[4]; char base64_group[4];
int base64_group_next; int base64_group_next;
int base64_padding; int base64_padding;
int base64_after_padding; int base64_after_padding;
int* base64_table; int* base64_table;
int errored; // if true, we encountered an error - skip further processing int errored; // if true, we encountered an error - skip further processing
Connection* conn; Connection* conn;
}; };
String* decode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr); String* decode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr);
String* encode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr); String* encode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr);
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -4,9 +4,8 @@
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
BifReturnVal::BifReturnVal(std::nullptr_t) noexcept { } BifReturnVal::BifReturnVal(std::nullptr_t) noexcept {}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,31 +6,27 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class Val; class Val;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
namespace detail namespace detail {
{
/** /**
* A simple wrapper class to use for the return value of BIFs so that * A simple wrapper class to use for the return value of BIFs so that
* they may return either a Val* or IntrusivePtr<Val> (the former could * they may return either a Val* or IntrusivePtr<Val> (the former could
* potentially be deprecated). * potentially be deprecated).
*/ */
class BifReturnVal class BifReturnVal {
{
public: public:
template <typename T> BifReturnVal(IntrusivePtr<T> v) noexcept : rval(AdoptRef{}, v.release()) template<typename T>
{ BifReturnVal(IntrusivePtr<T> v) noexcept : rval(AdoptRef{}, v.release()) {}
}
BifReturnVal(std::nullptr_t) noexcept; BifReturnVal(std::nullptr_t) noexcept;
ValPtr rval; ValPtr rval;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -9,43 +9,33 @@
#include "zeek/DFA.h" #include "zeek/DFA.h"
#include "zeek/RE.h" #include "zeek/RE.h"
namespace zeek::detail namespace zeek::detail {
{
CCL::CCL() CCL::CCL() {
{ syms = new int_list;
syms = new int_list; index = -(rem->InsertCCL(this) + 1);
index = -(rem->InsertCCL(this) + 1); negated = 0;
negated = 0; }
}
CCL::~CCL() CCL::~CCL() { delete syms; }
{
delete syms;
}
void CCL::Negate() void CCL::Negate() {
{ negated = 1;
negated = 1; Add(SYM_BOL);
Add(SYM_BOL); Add(SYM_EOL);
Add(SYM_EOL); }
}
void CCL::Add(int sym) void CCL::Add(int sym) {
{ auto sym_p = static_cast<std::intptr_t>(sym);
auto sym_p = static_cast<std::intptr_t>(sym);
// Check to see if the character is already in the ccl. // Check to see if the character is already in the ccl.
for ( auto sym_entry : *syms ) for ( auto sym_entry : *syms )
if ( sym_entry == sym_p ) if ( sym_entry == sym_p )
return; return;
syms->push_back(sym_p); syms->push_back(sym_p);
} }
void CCL::Sort() void CCL::Sort() { std::sort(syms->begin(), syms->end()); }
{
std::sort(syms->begin(), syms->end());
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -5,36 +5,33 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
namespace zeek::detail namespace zeek::detail {
{
using int_list = std::vector<std::intptr_t>; using int_list = std::vector<std::intptr_t>;
class CCL class CCL {
{
public: public:
CCL(); CCL();
~CCL(); ~CCL();
void Add(int sym); void Add(int sym);
void Negate(); void Negate();
bool IsNegated() { return negated != 0; } bool IsNegated() { return negated != 0; }
int Index() { return index; } int Index() { return index; }
void Sort(); void Sort();
int_list* Syms() { return syms; } int_list* Syms() { return syms; }
void ReplaceSyms(int_list* new_syms) void ReplaceSyms(int_list* new_syms) {
{ delete syms;
delete syms; syms = new_syms;
syms = new_syms; }
}
protected: protected:
int_list* syms; int_list* syms;
int negated; int negated;
int index; int index;
}; };
} // namespace zeek::detail } // namespace zeek::detail

File diff suppressed because it is too large Load diff

View file

@ -7,67 +7,61 @@
#include "zeek/Func.h" #include "zeek/Func.h"
#include "zeek/Type.h" #include "zeek/Type.h"
namespace zeek namespace zeek {
{
class ListVal; class ListVal;
using ListValPtr = zeek::IntrusivePtr<ListVal>; using ListValPtr = zeek::IntrusivePtr<ListVal>;
} // namespace zeek } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class HashKey; class HashKey;
class CompositeHash class CompositeHash {
{
public: public:
explicit CompositeHash(TypeListPtr composite_type); explicit CompositeHash(TypeListPtr composite_type);
// Compute the hash corresponding to the given index val, // Compute the hash corresponding to the given index val,
// or nullptr if it fails to typecheck. // or nullptr if it fails to typecheck.
std::unique_ptr<HashKey> MakeHashKey(const Val& v, bool type_check) const; std::unique_ptr<HashKey> MakeHashKey(const Val& v, bool type_check) const;
// Given a hash key, recover the values used to create it. // Given a hash key, recover the values used to create it.
ListValPtr RecoverVals(const HashKey& k) const; ListValPtr RecoverVals(const HashKey& k) const;
protected: protected:
bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, bool singleton) const;
bool singleton) const;
// Recovers just one Val of possibly many; called from RecoverVals. // Recovers just one Val of possibly many; called from RecoverVals.
// Upon return, pval will point to the recovered Val of type t. // Upon return, pval will point to the recovered Val of type t.
// Returns and updated kp for the next Val. Calls reporter->InternalError() // Returns and updated kp for the next Val. Calls reporter->InternalError()
// upon errors, so there is no return value for invalid input. // upon errors, so there is no return value for invalid input.
bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, bool singleton) const;
bool singleton) const;
// Compute the size of the composite key. If v is non-nil then // Compute the size of the composite key. If v is non-nil then
// the value is computed for the particular list of values. // the value is computed for the particular list of values.
// Returns 0 if the key has an indeterminate size (if v not given), // Returns 0 if the key has an indeterminate size (if v not given),
// or if v doesn't match the index type (if given). // or if v doesn't match the index type (if given).
bool ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool calc_static_size) const; bool ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool calc_static_size) const;
bool ReserveSingleTypeKeySize(HashKey& hk, Type*, const Val* v, bool type_check, bool optional, bool ReserveSingleTypeKeySize(HashKey& hk, Type*, const Val* v, bool type_check, bool optional,
bool calc_static_size, bool singleton) const; bool calc_static_size, bool singleton) const;
bool EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const; bool EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const;
// The following are for allowing hashing of function values. // The following are for allowing hashing of function values.
// These can occur, for example, in sets of predicates that get // These can occur, for example, in sets of predicates that get
// iterated over. We use pointers in order to keep storage // iterated over. We use pointers in order to keep storage
// lower for the common case of these not being needed. // lower for the common case of these not being needed.
std::unique_ptr<std::unordered_map<const Func*, uint32_t>> func_to_func_id; std::unique_ptr<std::unordered_map<const Func*, uint32_t>> func_to_func_id;
std::unique_ptr<std::vector<FuncPtr>> func_id_to_func; std::unique_ptr<std::vector<FuncPtr>> func_id_to_func;
void BuildFuncMappings() void BuildFuncMappings() {
{ func_to_func_id = std::make_unique<std::unordered_map<const Func*, uint32_t>>();
func_to_func_id = std::make_unique<std::unordered_map<const Func*, uint32_t>>(); func_id_to_func = std::make_unique<std::vector<FuncPtr>>();
func_id_to_func = std::make_unique<std::vector<FuncPtr>>(); }
}
TypeListPtr type; TypeListPtr type;
bool is_singleton = false; // if just one type in index bool is_singleton = false; // if just one type in index
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -22,470 +22,410 @@
#include "zeek/packet_analysis/protocol/tcp/TCP.h" #include "zeek/packet_analysis/protocol/tcp/TCP.h"
#include "zeek/session/Manager.h" #include "zeek/session/Manager.h"
namespace zeek namespace zeek {
{
uint64_t Connection::total_connections = 0; uint64_t Connection::total_connections = 0;
uint64_t Connection::current_connections = 0; uint64_t Connection::current_connections = 0;
Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt)
const Packet* pkt) : Session(t, connection_timeout, connection_status_update, detail::connection_status_update_interval), key(k) {
: Session(t, connection_timeout, connection_status_update, orig_addr = id->src_addr;
detail::connection_status_update_interval), resp_addr = id->dst_addr;
key(k) orig_port = id->src_port;
{ resp_port = id->dst_port;
orig_addr = id->src_addr; proto = TRANSPORT_UNKNOWN;
resp_addr = id->dst_addr; orig_flow_label = flow;
orig_port = id->src_port; resp_flow_label = 0;
resp_port = id->dst_port; saw_first_orig_packet = 1;
proto = TRANSPORT_UNKNOWN; saw_first_resp_packet = 0;
orig_flow_label = flow;
resp_flow_label = 0;
saw_first_orig_packet = 1;
saw_first_resp_packet = 0;
if ( pkt->l2_src ) if ( pkt->l2_src )
memcpy(orig_l2_addr, pkt->l2_src, sizeof(orig_l2_addr)); memcpy(orig_l2_addr, pkt->l2_src, sizeof(orig_l2_addr));
else else
memset(orig_l2_addr, 0, sizeof(orig_l2_addr)); memset(orig_l2_addr, 0, sizeof(orig_l2_addr));
if ( pkt->l2_dst ) if ( pkt->l2_dst )
memcpy(resp_l2_addr, pkt->l2_dst, sizeof(resp_l2_addr)); memcpy(resp_l2_addr, pkt->l2_dst, sizeof(resp_l2_addr));
else else
memset(resp_l2_addr, 0, sizeof(resp_l2_addr)); memset(resp_l2_addr, 0, sizeof(resp_l2_addr));
vlan = pkt->vlan; vlan = pkt->vlan;
inner_vlan = pkt->inner_vlan; inner_vlan = pkt->inner_vlan;
weird = 0; weird = 0;
suppress_event = 0; suppress_event = 0;
finished = 0; finished = 0;
hist_seen = 0; hist_seen = 0;
history = ""; history = "";
adapter = nullptr; adapter = nullptr;
primary_PIA = nullptr; primary_PIA = nullptr;
++current_connections; ++current_connections;
++total_connections; ++total_connections;
encapsulation = pkt->encap; encapsulation = pkt->encap;
} }
Connection::~Connection() Connection::~Connection() {
{ if ( ! finished )
if ( ! finished ) reporter->InternalError("Done() not called before destruction of Connection");
reporter->InternalError("Done() not called before destruction of Connection");
CancelTimers(); CancelTimers();
if ( conn_val ) if ( conn_val )
conn_val->SetOrigin(nullptr); conn_val->SetOrigin(nullptr);
delete adapter; delete adapter;
--current_connections; --current_connections;
} }
void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& arg_encap) void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& arg_encap) {
{ if ( encapsulation && arg_encap ) {
if ( encapsulation && arg_encap ) if ( *encapsulation != *arg_encap ) {
{ if ( tunnel_changed && (zeek::detail::tunnel_max_changes_per_connection == 0 ||
if ( *encapsulation != *arg_encap ) tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) ) {
{ tunnel_changes++;
if ( tunnel_changed && EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
(zeek::detail::tunnel_max_changes_per_connection == 0 || }
tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) )
{
tunnel_changes++;
EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
}
encapsulation = std::make_shared<EncapsulationStack>(*arg_encap); encapsulation = std::make_shared<EncapsulationStack>(*arg_encap);
} }
} }
else if ( encapsulation ) else if ( encapsulation ) {
{ if ( tunnel_changed ) {
if ( tunnel_changed ) EncapsulationStack empty;
{ EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal());
EncapsulationStack empty; }
EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal());
}
encapsulation = nullptr; encapsulation = nullptr;
} }
else if ( arg_encap ) else if ( arg_encap ) {
{ if ( tunnel_changed )
if ( tunnel_changed ) EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
encapsulation = std::make_shared<EncapsulationStack>(*arg_encap); encapsulation = std::make_shared<EncapsulationStack>(*arg_encap);
} }
} }
void Connection::Done() void Connection::Done() {
{ finished = 1;
finished = 1;
if ( adapter ) if ( adapter ) {
{ if ( ConnTransport() == TRANSPORT_TCP ) {
if ( ConnTransport() == TRANSPORT_TCP ) auto* ta = static_cast<packet_analysis::TCP::TCPSessionAdapter*>(adapter);
{ assert(ta->IsAnalyzer("TCP"));
auto* ta = static_cast<packet_analysis::TCP::TCPSessionAdapter*>(adapter); analyzer::tcp::TCP_Endpoint* to = ta->Orig();
assert(ta->IsAnalyzer("TCP")); analyzer::tcp::TCP_Endpoint* tr = ta->Resp();
analyzer::tcp::TCP_Endpoint* to = ta->Orig();
analyzer::tcp::TCP_Endpoint* tr = ta->Resp();
packet_analysis::TCP::TCPAnalyzer::GetStats().StateLeft(to->state, tr->state); packet_analysis::TCP::TCPAnalyzer::GetStats().StateLeft(to->state, tr->state);
} }
if ( ! adapter->IsFinished() ) if ( ! adapter->IsFinished() )
adapter->Done(); adapter->Done();
} }
} }
void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data,
const u_char*& data, int& record_packet, int& record_content, int& record_packet, int& record_content,
// arguments for reproducing packets // arguments for reproducing packets
const Packet* pkt) const Packet* pkt) {
{ run_state::current_timestamp = t;
run_state::current_timestamp = t; run_state::current_pkt = pkt;
run_state::current_pkt = pkt;
if ( adapter ) if ( adapter ) {
{ if ( adapter->Skipping() )
if ( adapter->Skipping() ) return;
return;
record_current_packet = record_packet; record_current_packet = record_packet;
record_current_content = record_content; record_current_content = record_content;
adapter->NextPacket(len, data, is_orig, -1, ip, caplen); adapter->NextPacket(len, data, is_orig, -1, ip, caplen);
record_packet = record_current_packet; record_packet = record_current_packet;
record_content = record_current_content; record_content = record_current_content;
} }
else else
last_time = t; last_time = t;
run_state::current_timestamp = 0; run_state::current_timestamp = 0;
run_state::current_pkt = nullptr; run_state::current_pkt = nullptr;
} }
bool Connection::IsReuse(double t, const u_char* pkt) bool Connection::IsReuse(double t, const u_char* pkt) { return adapter && adapter->IsReuse(t, pkt); }
{
return adapter && adapter->IsReuse(t, pkt);
}
bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base) {
uint32_t scaling_base) if ( ++counter == scaling_threshold ) {
{ AddHistory(code);
if ( ++counter == scaling_threshold )
{
AddHistory(code);
auto new_threshold = scaling_threshold * scaling_base; auto new_threshold = scaling_threshold * scaling_base;
if ( new_threshold <= scaling_threshold ) if ( new_threshold <= scaling_threshold )
// This can happen due to wrap-around. In that // This can happen due to wrap-around. In that
// case, reset the counter but leave the threshold // case, reset the counter but leave the threshold
// unchanged. // unchanged.
counter = 0; counter = 0;
else else
scaling_threshold = new_threshold; scaling_threshold = new_threshold;
return true; return true;
} }
return false; return false;
} }
void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) {
{ if ( ! e )
if ( ! e ) return;
return;
if ( threshold == 1 ) if ( threshold == 1 )
// This will be far and away the most common case, // This will be far and away the most common case,
// and at this stage it's not a *multiple* instance. // and at this stage it's not a *multiple* instance.
return; return;
EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold)); EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold));
} }
namespace namespace {
{
// Flip everything that needs to be flipped in the connection // Flip everything that needs to be flipped in the connection
// record that is known on this level. This needs to align // record that is known on this level. This needs to align
// with GetVal() and connection's layout in init-bare. // with GetVal() and connection's layout in init-bare.
void flip_conn_val(const RecordValPtr& conn_val) void flip_conn_val(const RecordValPtr& conn_val) {
{ // Flip the the conn_id (c$id).
// Flip the the conn_id (c$id). const auto& id_val = conn_val->GetField<zeek::RecordVal>(0);
const auto& id_val = conn_val->GetField<zeek::RecordVal>(0); const auto& tmp_addr = id_val->GetField<zeek::AddrVal>(0);
const auto& tmp_addr = id_val->GetField<zeek::AddrVal>(0); const auto& tmp_port = id_val->GetField<zeek::PortVal>(1);
const auto& tmp_port = id_val->GetField<zeek::PortVal>(1); id_val->Assign(0, id_val->GetField<zeek::AddrVal>(2));
id_val->Assign(0, id_val->GetField<zeek::AddrVal>(2)); id_val->Assign(1, id_val->GetField<zeek::PortVal>(3));
id_val->Assign(1, id_val->GetField<zeek::PortVal>(3)); id_val->Assign(2, tmp_addr);
id_val->Assign(2, tmp_addr); id_val->Assign(3, tmp_port);
id_val->Assign(3, tmp_port);
// Flip the endpoints within connection. // Flip the endpoints within connection.
const auto& tmp_endp = conn_val->GetField<zeek::RecordVal>(1); const auto& tmp_endp = conn_val->GetField<zeek::RecordVal>(1);
conn_val->Assign(1, conn_val->GetField(2)); conn_val->Assign(1, conn_val->GetField(2));
conn_val->Assign(2, tmp_endp); conn_val->Assign(2, tmp_endp);
} }
} } // namespace
const RecordValPtr& Connection::GetVal() const RecordValPtr& Connection::GetVal() {
{ if ( ! conn_val ) {
if ( ! conn_val ) conn_val = make_intrusive<RecordVal>(id::connection);
{
conn_val = make_intrusive<RecordVal>(id::connection);
TransportProto prot_type = ConnTransport(); TransportProto prot_type = ConnTransport();
auto id_val = make_intrusive<RecordVal>(id::conn_id); auto id_val = make_intrusive<RecordVal>(id::conn_id);
id_val->Assign(0, make_intrusive<AddrVal>(orig_addr)); id_val->Assign(0, make_intrusive<AddrVal>(orig_addr));
id_val->Assign(1, val_mgr->Port(ntohs(orig_port), prot_type)); id_val->Assign(1, val_mgr->Port(ntohs(orig_port), prot_type));
id_val->Assign(2, make_intrusive<AddrVal>(resp_addr)); id_val->Assign(2, make_intrusive<AddrVal>(resp_addr));
id_val->Assign(3, val_mgr->Port(ntohs(resp_port), prot_type)); id_val->Assign(3, val_mgr->Port(ntohs(resp_port), prot_type));
auto orig_endp = make_intrusive<RecordVal>(id::endpoint); auto orig_endp = make_intrusive<RecordVal>(id::endpoint);
orig_endp->Assign(0, 0); orig_endp->Assign(0, 0);
orig_endp->Assign(1, 0); orig_endp->Assign(1, 0);
orig_endp->Assign(4, orig_flow_label); orig_endp->Assign(4, orig_flow_label);
const int l2_len = sizeof(orig_l2_addr); const int l2_len = sizeof(orig_l2_addr);
char null[l2_len]{}; char null[l2_len]{};
if ( memcmp(&orig_l2_addr, &null, l2_len) != 0 ) if ( memcmp(&orig_l2_addr, &null, l2_len) != 0 )
orig_endp->Assign(5, fmt_mac(orig_l2_addr, l2_len)); orig_endp->Assign(5, fmt_mac(orig_l2_addr, l2_len));
auto resp_endp = make_intrusive<RecordVal>(id::endpoint); auto resp_endp = make_intrusive<RecordVal>(id::endpoint);
resp_endp->Assign(0, 0); resp_endp->Assign(0, 0);
resp_endp->Assign(1, 0); resp_endp->Assign(1, 0);
resp_endp->Assign(4, resp_flow_label); resp_endp->Assign(4, resp_flow_label);
if ( memcmp(&resp_l2_addr, &null, l2_len) != 0 ) if ( memcmp(&resp_l2_addr, &null, l2_len) != 0 )
resp_endp->Assign(5, fmt_mac(resp_l2_addr, l2_len)); resp_endp->Assign(5, fmt_mac(resp_l2_addr, l2_len));
conn_val->Assign(0, std::move(id_val)); conn_val->Assign(0, std::move(id_val));
conn_val->Assign(1, std::move(orig_endp)); conn_val->Assign(1, std::move(orig_endp));
conn_val->Assign(2, std::move(resp_endp)); conn_val->Assign(2, std::move(resp_endp));
// 3 and 4 are set below. // 3 and 4 are set below.
conn_val->Assign(5, make_intrusive<TableVal>(id::string_set)); // service conn_val->Assign(5, make_intrusive<TableVal>(id::string_set)); // service
conn_val->Assign(6, val_mgr->EmptyString()); // history conn_val->Assign(6, val_mgr->EmptyString()); // history
if ( ! uid ) if ( ! uid )
uid.Set(zeek::detail::bits_per_uid); uid.Set(zeek::detail::bits_per_uid);
conn_val->Assign(7, uid.Base62("C")); conn_val->Assign(7, uid.Base62("C"));
if ( encapsulation && encapsulation->Depth() > 0 ) if ( encapsulation && encapsulation->Depth() > 0 )
conn_val->Assign(8, encapsulation->ToVal()); conn_val->Assign(8, encapsulation->ToVal());
if ( vlan != 0 ) if ( vlan != 0 )
conn_val->Assign(9, vlan); conn_val->Assign(9, vlan);
if ( inner_vlan != 0 ) if ( inner_vlan != 0 )
conn_val->Assign(10, inner_vlan); conn_val->Assign(10, inner_vlan);
} }
if ( adapter ) if ( adapter )
adapter->UpdateConnVal(conn_val.get()); adapter->UpdateConnVal(conn_val.get());
conn_val->AssignTime(3, start_time); // ### conn_val->AssignTime(3, start_time); // ###
conn_val->AssignInterval(4, last_time - start_time); conn_val->AssignInterval(4, last_time - start_time);
if ( ! history.empty() ) if ( ! history.empty() ) {
{ auto v = conn_val->GetFieldAs<StringVal>(6);
auto v = conn_val->GetFieldAs<StringVal>(6); if ( *v != history )
if ( *v != history ) conn_val->Assign(6, history);
conn_val->Assign(6, history); }
}
conn_val->SetOrigin(this); conn_val->SetOrigin(this);
return conn_val; return conn_val;
} }
analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) { return adapter ? adapter->FindChild(id) : nullptr; }
{
return adapter ? adapter->FindChild(id) : nullptr;
}
analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) {
{ return adapter ? adapter->FindChild(tag) : nullptr;
return adapter ? adapter->FindChild(tag) : nullptr; }
}
analyzer::Analyzer* Connection::FindAnalyzer(const char* name) analyzer::Analyzer* Connection::FindAnalyzer(const char* name) { return adapter->FindChild(name); }
{
return adapter->FindChild(name);
}
void Connection::AppendAddl(const char* str) void Connection::AppendAddl(const char* str) {
{ const auto& cv = GetVal();
const auto& cv = GetVal();
const char* old = cv->GetFieldAs<StringVal>(6)->CheckString(); const char* old = cv->GetFieldAs<StringVal>(6)->CheckString();
const char* format = *old ? "%s %s" : "%s%s"; const char* format = *old ? "%s %s" : "%s%s";
cv->Assign(6, util::fmt(format, old, str)); cv->Assign(6, util::fmt(format, old, str));
} }
void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol,
bool bol, bool eol, bool clear_state) bool clear_state) {
{ if ( primary_PIA )
if ( primary_PIA ) primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state);
primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state); }
}
void Connection::RemovalEvent() void Connection::RemovalEvent() {
{ if ( connection_state_remove )
if ( connection_state_remove ) EnqueueEvent(connection_state_remove, nullptr, GetVal());
EnqueueEvent(connection_state_remove, nullptr, GetVal()); }
}
void Connection::Weird(const char* name, const char* addl, const char* source) void Connection::Weird(const char* name, const char* addl, const char* source) {
{ weird = 1;
weird = 1; reporter->Weird(this, name, addl ? addl : "", source ? source : "");
reporter->Weird(this, name, addl ? addl : "", source ? source : ""); }
}
void Connection::FlipRoles() void Connection::FlipRoles() {
{ IPAddr tmp_addr = resp_addr;
IPAddr tmp_addr = resp_addr; resp_addr = orig_addr;
resp_addr = orig_addr; orig_addr = tmp_addr;
orig_addr = tmp_addr;
uint32_t tmp_port = resp_port; uint32_t tmp_port = resp_port;
resp_port = orig_port; resp_port = orig_port;
orig_port = tmp_port; orig_port = tmp_port;
const int l2_len = sizeof(orig_l2_addr); const int l2_len = sizeof(orig_l2_addr);
u_char tmp_l2_addr[l2_len]; u_char tmp_l2_addr[l2_len];
memcpy(tmp_l2_addr, resp_l2_addr, l2_len); memcpy(tmp_l2_addr, resp_l2_addr, l2_len);
memcpy(resp_l2_addr, orig_l2_addr, l2_len); memcpy(resp_l2_addr, orig_l2_addr, l2_len);
memcpy(orig_l2_addr, tmp_l2_addr, l2_len); memcpy(orig_l2_addr, tmp_l2_addr, l2_len);
bool tmp_bool = saw_first_resp_packet; bool tmp_bool = saw_first_resp_packet;
saw_first_resp_packet = saw_first_orig_packet; saw_first_resp_packet = saw_first_orig_packet;
saw_first_orig_packet = tmp_bool; saw_first_orig_packet = tmp_bool;
uint32_t tmp_flow = resp_flow_label; uint32_t tmp_flow = resp_flow_label;
resp_flow_label = orig_flow_label; resp_flow_label = orig_flow_label;
orig_flow_label = tmp_flow; orig_flow_label = tmp_flow;
if ( conn_val ) if ( conn_val )
flip_conn_val(conn_val); flip_conn_val(conn_val);
if ( adapter ) if ( adapter )
adapter->FlipRoles(); adapter->FlipRoles();
analyzer_mgr->ApplyScheduledAnalyzers(this); analyzer_mgr->ApplyScheduledAnalyzers(this);
AddHistory('^'); AddHistory('^');
if ( connection_flipped ) if ( connection_flipped )
EnqueueEvent(connection_flipped, nullptr, GetVal()); EnqueueEvent(connection_flipped, nullptr, GetVal());
} }
void Connection::Describe(ODesc* d) const void Connection::Describe(ODesc* d) const {
{ session::Session::Describe(d);
session::Session::Describe(d);
switch ( proto ) switch ( proto ) {
{ case TRANSPORT_TCP: d->Add("TCP"); break;
case TRANSPORT_TCP:
d->Add("TCP");
break;
case TRANSPORT_UDP: case TRANSPORT_UDP: d->Add("UDP"); break;
d->Add("UDP");
break;
case TRANSPORT_ICMP: case TRANSPORT_ICMP: d->Add("ICMP"); break;
d->Add("ICMP");
break;
case TRANSPORT_UNKNOWN: case TRANSPORT_UNKNOWN:
d->Add("unknown"); d->Add("unknown");
reporter->InternalWarning("unknown transport in Connection::Describe()"); reporter->InternalWarning("unknown transport in Connection::Describe()");
break; break;
default: default: reporter->InternalError("unhandled transport type in Connection::Describe");
reporter->InternalError("unhandled transport type in Connection::Describe"); }
}
d->SP(); d->SP();
d->Add(orig_addr); d->Add(orig_addr);
d->Add(":"); d->Add(":");
d->Add(ntohs(orig_port)); d->Add(ntohs(orig_port));
d->SP(); d->SP();
d->AddSP("->"); d->AddSP("->");
d->Add(resp_addr); d->Add(resp_addr);
d->Add(":"); d->Add(":");
d->Add(ntohs(resp_port)); d->Add(ntohs(resp_port));
d->NL(); d->NL();
} }
void Connection::IDString(ODesc* d) const void Connection::IDString(ODesc* d) const {
{ d->Add(orig_addr);
d->Add(orig_addr); d->AddRaw(":", 1);
d->AddRaw(":", 1); d->Add(ntohs(orig_port));
d->Add(ntohs(orig_port)); d->AddRaw(" > ", 3);
d->AddRaw(" > ", 3); d->Add(resp_addr);
d->Add(resp_addr); d->AddRaw(":", 1);
d->AddRaw(":", 1); d->Add(ntohs(resp_port));
d->Add(ntohs(resp_port)); }
}
void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) {
{ adapter = aa;
adapter = aa; primary_PIA = pia;
primary_PIA = pia; }
}
void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) {
{ uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label;
uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label;
if ( my_flow_label != flow_label ) if ( my_flow_label != flow_label ) {
{ if ( conn_val ) {
if ( conn_val ) RecordVal* endp = conn_val->GetFieldAs<RecordVal>(is_orig ? 1 : 2);
{ endp->Assign(4, flow_label);
RecordVal* endp = conn_val->GetFieldAs<RecordVal>(is_orig ? 1 : 2); }
endp->Assign(4, flow_label);
}
if ( connection_flow_label_changed && if ( connection_flow_label_changed && (is_orig ? saw_first_orig_packet : saw_first_resp_packet) ) {
(is_orig ? saw_first_orig_packet : saw_first_resp_packet) ) EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig),
{ val_mgr->Count(my_flow_label), val_mgr->Count(flow_label));
EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig), }
val_mgr->Count(my_flow_label), val_mgr->Count(flow_label));
}
my_flow_label = flow_label; my_flow_label = flow_label;
} }
if ( is_orig ) if ( is_orig )
saw_first_orig_packet = 1; saw_first_orig_packet = 1;
else else
saw_first_resp_packet = 1; saw_first_resp_packet = 1;
} }
bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) {
{ return detail::PermitWeird(weird_state, name, threshold, rate, duration);
return detail::PermitWeird(weird_state, name, threshold, rate, duration); }
}
} // namespace zeek } // namespace zeek

View file

@ -19,8 +19,7 @@
#include "zeek/iosource/Packet.h" #include "zeek/iosource/Packet.h"
#include "zeek/session/Session.h" #include "zeek/session/Session.h"
namespace zeek namespace zeek {
{
class Connection; class Connection;
class EncapsulationStack; class EncapsulationStack;
@ -30,253 +29,235 @@ class RecordVal;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using RecordValPtr = IntrusivePtr<RecordVal>; using RecordValPtr = IntrusivePtr<RecordVal>;
namespace session namespace session {
{
class Manager; class Manager;
} }
namespace detail namespace detail {
{
class Specific_RE_Matcher; class Specific_RE_Matcher;
class RuleEndpointState; class RuleEndpointState;
class RuleHdrTest; class RuleHdrTest;
} // namespace detail } // namespace detail
namespace analyzer namespace analyzer {
{
class Analyzer; class Analyzer;
} }
namespace packet_analysis::IP namespace packet_analysis::IP {
{
class SessionAdapter; class SessionAdapter;
} }
enum ConnEventToFlag enum ConnEventToFlag {
{ NUL_IN_LINE,
NUL_IN_LINE, SINGULAR_CR,
SINGULAR_CR, SINGULAR_LF,
SINGULAR_LF, NUM_EVENTS_TO_FLAG,
NUM_EVENTS_TO_FLAG, };
};
struct ConnTuple struct ConnTuple {
{ IPAddr src_addr;
IPAddr src_addr; IPAddr dst_addr;
IPAddr dst_addr; uint32_t src_port = 0;
uint32_t src_port = 0; uint32_t dst_port = 0;
uint32_t dst_port = 0; bool is_one_way = false; // if true, don't canonicalize order
bool is_one_way = false; // if true, don't canonicalize order TransportProto proto = TRANSPORT_UNKNOWN;
TransportProto proto = TRANSPORT_UNKNOWN; };
};
static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, uint32_t p2) {
uint32_t p2) return addr1 < addr2 || (addr1 == addr2 && p1 < p2);
{ }
return addr1 < addr2 || (addr1 == addr2 && p1 < p2);
}
class Connection final : public session::Session class Connection final : public session::Session {
{
public: public:
Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt);
const Packet* pkt); ~Connection() override;
~Connection() override;
/** /**
* Invoked when an encapsulation is discovered. It records the encapsulation * Invoked when an encapsulation is discovered. It records the encapsulation
* with the connection and raises a "tunnel_changed" event if it's different * with the connection and raises a "tunnel_changed" event if it's different
* from the previous encapsulation or if it's the first one encountered. * from the previous encapsulation or if it's the first one encountered.
* *
* @param encap The new encapsulation. Can be set to null to indicated no * @param encap The new encapsulation. Can be set to null to indicated no
* encapsulation or clear an old one. * encapsulation or clear an old one.
*/ */
void CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& encap); void CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& encap);
/** /**
* Invoked when the session is about to be removed. Use Ref(this) * Invoked when the session is about to be removed. Use Ref(this)
* inside Done to keep the session object around, though it'll * inside Done to keep the session object around, though it'll
* no longer be accessible from the SessionManager. * no longer be accessible from the SessionManager.
*/ */
void Done() override; void Done() override;
// Process the connection's next packet. "data" points just // Process the connection's next packet. "data" points just
// beyond the IP header. It's updated to point just beyond // beyond the IP header. It's updated to point just beyond
// the transport header (or whatever should be saved, if we // the transport header (or whatever should be saved, if we
// decide not to save the full packet contents). // decide not to save the full packet contents).
// //
// If record_packet is true, the packet should be recorded. // If record_packet is true, the packet should be recorded.
// If record_content is true, then its entire contents should // If record_content is true, then its entire contents should
// be recorded, otherwise just up through the transport header. // be recorded, otherwise just up through the transport header.
// Both are assumed set to true when called. // Both are assumed set to true when called.
void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data,
const u_char*& data, int& record_packet, int& record_content, int& record_packet, int& record_content,
// arguments for reproducing packets // arguments for reproducing packets
const Packet* pkt); const Packet* pkt);
// Keys are only considered valid for a connection when a // Keys are only considered valid for a connection when a
// connection is in the session map. If it is removed, the key // connection is in the session map. If it is removed, the key
// should be marked invalid. // should be marked invalid.
const detail::ConnKey& Key() const { return key; } const detail::ConnKey& Key() const { return key; }
session::detail::Key SessionKey(bool copy) const override session::detail::Key SessionKey(bool copy) const override {
{ return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE, copy};
return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE, }
copy};
}
const IPAddr& OrigAddr() const { return orig_addr; } const IPAddr& OrigAddr() const { return orig_addr; }
const IPAddr& RespAddr() const { return resp_addr; } const IPAddr& RespAddr() const { return resp_addr; }
uint32_t OrigPort() const { return orig_port; } uint32_t OrigPort() const { return orig_port; }
uint32_t RespPort() const { return resp_port; } uint32_t RespPort() const { return resp_port; }
void FlipRoles(); void FlipRoles();
analyzer::Analyzer* FindAnalyzer(analyzer::ID id); analyzer::Analyzer* FindAnalyzer(analyzer::ID id);
analyzer::Analyzer* FindAnalyzer(const zeek::Tag& tag); // find first in tree. analyzer::Analyzer* FindAnalyzer(const zeek::Tag& tag); // find first in tree.
analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree. analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree.
TransportProto ConnTransport() const { return proto; } TransportProto ConnTransport() const { return proto; }
std::string TransportIdentifier() const override std::string TransportIdentifier() const override {
{ if ( proto == TRANSPORT_TCP )
if ( proto == TRANSPORT_TCP ) return "tcp";
return "tcp"; else if ( proto == TRANSPORT_UDP )
else if ( proto == TRANSPORT_UDP ) return "udp";
return "udp"; else if ( proto == TRANSPORT_ICMP )
else if ( proto == TRANSPORT_ICMP ) return "icmp";
return "icmp"; else
else return "unknown";
return "unknown"; }
}
// Returns true if the packet reflects a reuse of this // Returns true if the packet reflects a reuse of this
// connection (i.e., not a continuation but the beginning of // connection (i.e., not a continuation but the beginning of
// a new connection). // a new connection).
bool IsReuse(double t, const u_char* pkt); bool IsReuse(double t, const u_char* pkt);
/** /**
* Returns the associated "connection" record. * Returns the associated "connection" record.
*/ */
const RecordValPtr& GetVal() override; const RecordValPtr& GetVal() override;
/** /**
* Append additional entries to the history field in the connection record. * Append additional entries to the history field in the connection record.
*/ */
void AppendAddl(const char* str); void AppendAddl(const char* str);
void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol,
bool eol, bool clear_state); bool clear_state);
/** /**
* Generates connection removal event(s). * Generates connection removal event(s).
*/ */
void RemovalEvent() override; void RemovalEvent() override;
void Weird(const char* name, const char* addl = "", const char* source = ""); void Weird(const char* name, const char* addl = "", const char* source = "");
bool DidWeird() const { return weird != 0; } bool DidWeird() const { return weird != 0; }
inline bool FlagEvent(ConnEventToFlag e) inline bool FlagEvent(ConnEventToFlag e) {
{ if ( e >= 0 && e < NUM_EVENTS_TO_FLAG ) {
if ( e >= 0 && e < NUM_EVENTS_TO_FLAG ) if ( suppress_event & (1 << e) )
{ return false;
if ( suppress_event & (1 << e) ) suppress_event |= 1 << e;
return false; }
suppress_event |= 1 << e;
}
return true; return true;
} }
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
void IDString(ODesc* d) const; void IDString(ODesc* d) const;
// Statistics. // Statistics.
static uint64_t TotalConnections() { return total_connections; } static uint64_t TotalConnections() { return total_connections; }
static uint64_t CurrentConnections() { return current_connections; } static uint64_t CurrentConnections() { return current_connections; }
// Returns true if the history was already seen, false otherwise. // Returns true if the history was already seen, false otherwise.
bool CheckHistory(uint32_t mask, char code) bool CheckHistory(uint32_t mask, char code) {
{ if ( (hist_seen & mask) == 0 ) {
if ( (hist_seen & mask) == 0 ) hist_seen |= mask;
{ AddHistory(code);
hist_seen |= mask; return false;
AddHistory(code); }
return false; else
} return true;
else }
return true;
}
// Increments the passed counter and adds it as a history // Increments the passed counter and adds it as a history
// code if it has crossed the next scaling threshold. Scaling // code if it has crossed the next scaling threshold. Scaling
// is done in terms of powers of the third argument. // is done in terms of powers of the third argument.
// Returns true if the threshold was crossed, false otherwise. // Returns true if the threshold was crossed, false otherwise.
bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base = 10);
uint32_t scaling_base = 10);
void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold); void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold);
void AddHistory(char code) { history += code; } void AddHistory(char code) { history += code; }
const std::string& GetHistory() const { return history; } const std::string& GetHistory() const { return history; }
void ReplaceHistory(std::string new_h) { history = std::move(new_h); } void ReplaceHistory(std::string new_h) { history = std::move(new_h); }
// Sets the root of the analyzer tree as well as the primary PIA. // Sets the root of the analyzer tree as well as the primary PIA.
void SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia); void SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia);
packet_analysis::IP::SessionAdapter* GetSessionAdapter() { return adapter; } packet_analysis::IP::SessionAdapter* GetSessionAdapter() { return adapter; }
analyzer::pia::PIA* GetPrimaryPIA() { return primary_PIA; } analyzer::pia::PIA* GetPrimaryPIA() { return primary_PIA; }
// Sets the transport protocol in use. // Sets the transport protocol in use.
void SetTransport(TransportProto arg_proto) { proto = arg_proto; } void SetTransport(TransportProto arg_proto) { proto = arg_proto; }
void SetUID(const UID& arg_uid) { uid = arg_uid; } void SetUID(const UID& arg_uid) { uid = arg_uid; }
UID GetUID() const { return uid; } UID GetUID() const { return uid; }
std::shared_ptr<EncapsulationStack> GetEncapsulation() const { return encapsulation; } std::shared_ptr<EncapsulationStack> GetEncapsulation() const { return encapsulation; }
void CheckFlowLabel(bool is_orig, uint32_t flow_label); void CheckFlowLabel(bool is_orig, uint32_t flow_label);
uint32_t GetOrigFlowLabel() { return orig_flow_label; } uint32_t GetOrigFlowLabel() { return orig_flow_label; }
uint32_t GetRespFlowLabel() { return resp_flow_label; } uint32_t GetRespFlowLabel() { return resp_flow_label; }
bool PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration); bool PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration);
private: private:
friend class session::detail::Timer; friend class session::detail::Timer;
IPAddr orig_addr; IPAddr orig_addr;
IPAddr resp_addr; IPAddr resp_addr;
uint32_t orig_port, resp_port; // in network order uint32_t orig_port, resp_port; // in network order
TransportProto proto; TransportProto proto;
uint32_t orig_flow_label, resp_flow_label; // most recent IPv6 flow labels uint32_t orig_flow_label, resp_flow_label; // most recent IPv6 flow labels
uint32_t vlan, inner_vlan; // VLAN this connection traverses, if available uint32_t vlan, inner_vlan; // VLAN this connection traverses, if available
u_char orig_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer originator address, if available u_char orig_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer originator address, if available
u_char resp_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer responder address, if available u_char resp_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer responder address, if available
int suppress_event; // suppress certain events to once per conn. int suppress_event; // suppress certain events to once per conn.
RecordValPtr conn_val; RecordValPtr conn_val;
std::shared_ptr<EncapsulationStack> encapsulation; // tunnels std::shared_ptr<EncapsulationStack> encapsulation; // tunnels
uint8_t tunnel_changes = 0; uint8_t tunnel_changes = 0;
detail::ConnKey key; detail::ConnKey key;
unsigned int weird : 1; unsigned int weird : 1;
unsigned int finished : 1; unsigned int finished : 1;
unsigned int saw_first_orig_packet : 1, saw_first_resp_packet : 1; unsigned int saw_first_orig_packet : 1, saw_first_resp_packet : 1;
uint32_t hist_seen; uint32_t hist_seen;
std::string history; std::string history;
packet_analysis::IP::SessionAdapter* adapter; packet_analysis::IP::SessionAdapter* adapter;
analyzer::pia::PIA* primary_PIA; analyzer::pia::PIA* primary_PIA;
UID uid; // Globally unique connection ID. UID uid; // Globally unique connection ID.
detail::WeirdStateMap weird_state; detail::WeirdStateMap weird_state;
// Count number of connections. // Count number of connections.
static uint64_t total_connections; static uint64_t total_connections;
static uint64_t current_connections; static uint64_t current_connections;
}; };
} // namespace zeek } // namespace zeek

View file

@ -8,455 +8,390 @@
#include "zeek/EquivClass.h" #include "zeek/EquivClass.h"
#include "zeek/Hash.h" #include "zeek/Hash.h"
namespace zeek::detail namespace zeek::detail {
{
DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* arg_nfa_states, DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* arg_nfa_states,
AcceptingSet* arg_accept) AcceptingSet* arg_accept) {
{ state_num = arg_state_num;
state_num = arg_state_num; num_sym = ec->NumClasses();
num_sym = ec->NumClasses(); nfa_states = arg_nfa_states;
nfa_states = arg_nfa_states; accept = arg_accept;
accept = arg_accept; mark = nullptr;
mark = nullptr;
SymPartition(ec);
SymPartition(ec);
xtions = new DFA_State*[num_sym];
xtions = new DFA_State*[num_sym];
for ( int i = 0; i < num_sym; ++i )
for ( int i = 0; i < num_sym; ++i ) xtions[i] = DFA_UNCOMPUTED_STATE_PTR;
xtions[i] = DFA_UNCOMPUTED_STATE_PTR; }
}
DFA_State::~DFA_State() {
DFA_State::~DFA_State() delete[] xtions;
{ delete nfa_states;
delete[] xtions; delete accept;
delete nfa_states; delete meta_ec;
delete accept; }
delete meta_ec;
} void DFA_State::AddXtion(int sym, DFA_State* next_state) { xtions[sym] = next_state; }
void DFA_State::AddXtion(int sym, DFA_State* next_state) void DFA_State::SymPartition(const EquivClass* ec) {
{ // Partitioning is done by creating equivalence classes for those
xtions[sym] = next_state; // characters which have out-transitions from the given state. Thus
} // we are really creating equivalence classes of equivalence classes.
meta_ec = new EquivClass(ec->NumClasses());
void DFA_State::SymPartition(const EquivClass* ec)
{ assert(nfa_states);
// Partitioning is done by creating equivalence classes for those for ( int i = 0; i < nfa_states->length(); ++i ) {
// characters which have out-transitions from the given state. Thus NFA_State* n = (*nfa_states)[i];
// we are really creating equivalence classes of equivalence classes. int sym = n->TransSym();
meta_ec = new EquivClass(ec->NumClasses());
if ( sym == SYM_EPSILON )
assert(nfa_states); continue;
for ( int i = 0; i < nfa_states->length(); ++i )
{ if ( sym != SYM_CCL ) { // character transition
NFA_State* n = (*nfa_states)[i]; if ( ec->IsRep(sym) ) {
int sym = n->TransSym(); sym = ec->SymEquivClass(sym);
meta_ec->UniqueChar(sym);
if ( sym == SYM_EPSILON ) }
continue; continue;
}
if ( sym != SYM_CCL )
{ // character transition // Character class.
if ( ec->IsRep(sym) ) meta_ec->CCL_Use(n->TransCCL());
{ }
sym = ec->SymEquivClass(sym);
meta_ec->UniqueChar(sym); meta_ec->BuildECs();
} }
continue;
} DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) {
int equiv_sym = meta_ec->EquivRep(sym);
// Character class. if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) {
meta_ec->CCL_Use(n->TransCCL()); AddXtion(sym, xtions[equiv_sym]);
} return xtions[sym];
}
meta_ec->BuildECs();
} const EquivClass* ec = machine->EC();
DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) DFA_State* next_d;
{
int equiv_sym = meta_ec->EquivRep(sym); NFA_state_list* ns = SymFollowSet(equiv_sym, ec);
if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) if ( ns->length() > 0 ) {
{ NFA_state_list* state_set = epsilon_closure(ns);
AddXtion(sym, xtions[equiv_sym]); if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) )
return xtions[sym]; delete state_set;
} }
else {
const EquivClass* ec = machine->EC(); delete ns;
next_d = nullptr; // Jam
DFA_State* next_d; }
NFA_state_list* ns = SymFollowSet(equiv_sym, ec); AddXtion(equiv_sym, next_d);
if ( ns->length() > 0 ) if ( sym != equiv_sym )
{ AddXtion(sym, next_d);
NFA_state_list* state_set = epsilon_closure(ns);
if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) ) return xtions[sym];
delete state_set; }
}
else void DFA_State::AppendIfNew(int sym, int_list* sym_list) {
{ for ( auto value : *sym_list )
delete ns; if ( value == sym )
next_d = nullptr; // Jam return;
}
sym_list->push_back(sym);
AddXtion(equiv_sym, next_d); }
if ( sym != equiv_sym )
AddXtion(sym, next_d); NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) {
NFA_state_list* ns = new NFA_state_list;
return xtions[sym];
} assert(nfa_states);
void DFA_State::AppendIfNew(int sym, int_list* sym_list) for ( int i = 0; i < nfa_states->length(); ++i ) {
{ NFA_State* n = (*nfa_states)[i];
for ( auto value : *sym_list )
if ( value == sym ) if ( n->TransSym() == SYM_CCL ) { // it's a character class
return; CCL* ccl = n->TransCCL();
int_list* syms = ccl->Syms();
sym_list->push_back(sym);
} if ( ccl->IsNegated() ) {
size_t j;
NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) for ( j = 0; j < syms->size(); ++j ) {
{ // Loop through (sorted) negated
NFA_state_list* ns = new NFA_state_list; // character class, which has
// presumably already been converted
assert(nfa_states); // over to equivalence classes.
if ( (*syms)[j] >= ec_sym )
for ( int i = 0; i < nfa_states->length(); ++i ) break;
{ }
NFA_State* n = (*nfa_states)[i];
if ( j >= syms->size() || (*syms)[j] > ec_sym )
if ( n->TransSym() == SYM_CCL ) // Didn't find ec_sym in ccl.
{ // it's a character class n->AddXtionsTo(ns);
CCL* ccl = n->TransCCL();
int_list* syms = ccl->Syms(); continue;
}
if ( ccl->IsNegated() )
{ for ( auto sym : *syms ) {
size_t j; if ( sym > ec_sym )
for ( j = 0; j < syms->size(); ++j ) break;
{
// Loop through (sorted) negated if ( sym == ec_sym ) {
// character class, which has n->AddXtionsTo(ns);
// presumably already been converted break;
// over to equivalence classes. }
if ( (*syms)[j] >= ec_sym ) }
break; }
}
else if ( n->TransSym() == SYM_EPSILON ) { // do nothing
if ( j >= syms->size() || (*syms)[j] > ec_sym ) }
// Didn't find ec_sym in ccl.
n->AddXtionsTo(ns); else if ( ec->IsRep(n->TransSym()) ) {
if ( ec_sym == ec->SymEquivClass(n->TransSym()) )
continue; n->AddXtionsTo(ns);
} }
}
for ( auto sym : *syms )
{ ns->resize(0);
if ( sym > ec_sym ) return ns;
break; }
if ( sym == ec_sym ) void DFA_State::ClearMarks() {
{ if ( mark ) {
n->AddXtionsTo(ns); SetMark(nullptr);
break;
} for ( int i = 0; i < num_sym; ++i ) {
} DFA_State* s = xtions[i];
}
if ( s && s != DFA_UNCOMPUTED_STATE_PTR )
else if ( n->TransSym() == SYM_EPSILON ) xtions[i]->ClearMarks();
{ // do nothing }
} }
}
else if ( ec->IsRep(n->TransSym()) )
{ void DFA_State::Describe(ODesc* d) const { d->Add("DFA state"); }
if ( ec_sym == ec->SymEquivClass(n->TransSym()) )
n->AddXtionsTo(ns); void DFA_State::Dump(FILE* f, DFA_Machine* m) {
} if ( mark )
} return;
ns->resize(0); fprintf(f, "\nDFA state %d:", StateNum());
return ns;
} if ( accept ) {
AcceptingSet::const_iterator it;
void DFA_State::ClearMarks()
{ for ( it = accept->begin(); it != accept->end(); ++it )
if ( mark ) fprintf(f, "%s accept #%d", it == accept->begin() ? "" : ",", *it);
{ }
SetMark(nullptr);
fprintf(f, "\n");
for ( int i = 0; i < num_sym; ++i )
{ int num_trans = 0;
DFA_State* s = xtions[i]; for ( int sym = 0; sym < num_sym; ++sym ) {
DFA_State* s = xtions[sym];
if ( s && s != DFA_UNCOMPUTED_STATE_PTR )
xtions[i]->ClearMarks(); if ( ! s )
} continue;
}
} // Look ahead for compression.
int i;
void DFA_State::Describe(ODesc* d) const for ( i = sym + 1; i < num_sym; ++i )
{ if ( xtions[i] != s )
d->Add("DFA state"); break;
}
constexpr int xbuf_size = 512;
void DFA_State::Dump(FILE* f, DFA_Machine* m) char* xbuf = new char[xbuf_size];
{
if ( mark ) int r = m->Rep(sym);
return; if ( ! r )
r = '.';
fprintf(f, "\nDFA state %d:", StateNum());
if ( i == sym + 1 )
if ( accept ) snprintf(xbuf, xbuf_size, "'%c'", r);
{ else
AcceptingSet::const_iterator it; snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1));
for ( it = accept->begin(); it != accept->end(); ++it ) if ( s == DFA_UNCOMPUTED_STATE_PTR )
fprintf(f, "%s accept #%d", it == accept->begin() ? "" : ",", *it); fprintf(f, "%stransition on %s to <uncomputed>", ++num_trans == 1 ? "\t" : "\n\t", xbuf);
} else
fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, s->StateNum());
fprintf(f, "\n");
delete[] xbuf;
int num_trans = 0;
for ( int sym = 0; sym < num_sym; ++sym ) sym = i - 1;
{ }
DFA_State* s = xtions[sym];
if ( num_trans > 0 )
if ( ! s ) fprintf(f, "\n");
continue;
SetMark(this);
// Look ahead for compression.
int i; for ( int sym = 0; sym < num_sym; ++sym ) {
for ( i = sym + 1; i < num_sym; ++i ) DFA_State* s = xtions[sym];
if ( xtions[i] != s )
break; if ( s && s != DFA_UNCOMPUTED_STATE_PTR )
s->Dump(f, m);
constexpr int xbuf_size = 512; }
char* xbuf = new char[xbuf_size]; }
int r = m->Rep(sym); void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed) {
if ( ! r ) for ( int sym = 0; sym < num_sym; ++sym ) {
r = '.'; DFA_State* s = xtions[sym];
if ( i == sym + 1 ) if ( s == DFA_UNCOMPUTED_STATE_PTR )
snprintf(xbuf, xbuf_size, "'%c'", r); (*uncomputed)++;
else else
snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1)); (*computed)++;
}
if ( s == DFA_UNCOMPUTED_STATE_PTR ) }
fprintf(f, "%stransition on %s to <uncomputed>", ++num_trans == 1 ? "\t" : "\n\t",
xbuf); unsigned int DFA_State::Size() {
else return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) +
fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, (accept ? util::pad_size(sizeof(int) * accept->size()) : 0) +
s->StateNum()); (nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) +
(meta_ec ? meta_ec->Size() : 0);
delete[] xbuf; }
sym = i - 1; DFA_State_Cache::DFA_State_Cache() { hits = misses = 0; }
}
DFA_State_Cache::~DFA_State_Cache() {
if ( num_trans > 0 ) for ( auto& entry : states ) {
fprintf(f, "\n"); assert(entry.second);
Unref(entry.second);
SetMark(this); }
for ( int sym = 0; sym < num_sym; ++sym ) states.clear();
{ }
DFA_State* s = xtions[sym];
DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest) {
if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) // We assume that state ID's don't exceed 10 digits, plus
s->Dump(f, m); // we allow one more character for the delimiter.
} auto id_tag_buf = std::make_unique<u_char[]>(nfas.length() * 11 + 1);
} auto id_tag = id_tag_buf.get();
u_char* p = id_tag;
void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed)
{ for ( int i = 0; i < nfas.length(); ++i ) {
for ( int sym = 0; sym < num_sym; ++sym ) NFA_State* n = nfas[i];
{ if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) {
DFA_State* s = xtions[sym]; int id = n->ID();
do {
if ( s == DFA_UNCOMPUTED_STATE_PTR ) *p++ = '0' + (char)(id % 10);
(*uncomputed)++; id /= 10;
else } while ( id > 0 );
(*computed)++; *p++ = '&';
} }
} }
unsigned int DFA_State::Size() *p++ = '\0';
{
return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) + // We use the short MD5 instead of the full string for the
(accept ? util::pad_size(sizeof(int) * accept->size()) : 0) + // HashKey because the data is copied into the key.
(nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) + hash128_t hash;
(meta_ec ? meta_ec->Size() : 0); KeyedHash::Hash128(id_tag, p - id_tag, &hash);
} *digest = DigestStr(reinterpret_cast<const unsigned char*>(hash), 16);
DFA_State_Cache::DFA_State_Cache() auto entry = states.find(*digest);
{ if ( entry == states.end() ) {
hits = misses = 0; ++misses;
} return nullptr;
}
DFA_State_Cache::~DFA_State_Cache() ++hits;
{
for ( auto& entry : states ) digest->clear();
{
assert(entry.second); return entry->second;
Unref(entry.second); }
}
DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) {
states.clear(); states.emplace(std::move(digest), state);
} return state;
}
DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest)
{ void DFA_State_Cache::GetStats(Stats* s) {
// We assume that state ID's don't exceed 10 digits, plus s->dfa_states = 0;
// we allow one more character for the delimiter. s->nfa_states = 0;
auto id_tag_buf = std::make_unique<u_char[]>(nfas.length() * 11 + 1); s->computed = 0;
auto id_tag = id_tag_buf.get(); s->uncomputed = 0;
u_char* p = id_tag; s->mem = 0;
s->hits = hits;
for ( int i = 0; i < nfas.length(); ++i ) s->misses = misses;
{
NFA_State* n = nfas[i]; for ( const auto& state : states ) {
if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) DFA_State* e = state.second;
{ ++s->dfa_states;
int id = n->ID(); s->nfa_states += e->NFAStateNum();
do e->Stats(&s->computed, &s->uncomputed);
{ s->mem += util::pad_size(e->Size()) + padded_sizeof(*e);
*p++ = '0' + (char)(id % 10); }
id /= 10; }
} while ( id > 0 );
*p++ = '&'; DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) {
} state_count = 0;
}
nfa = n;
*p++ = '\0'; Ref(n);
// We use the short MD5 instead of the full string for the ec = arg_ec;
// HashKey because the data is copied into the key.
hash128_t hash; dfa_state_cache = new DFA_State_Cache();
KeyedHash::Hash128(id_tag, p - id_tag, &hash);
*digest = DigestStr(reinterpret_cast<const unsigned char*>(hash), 16); NFA_state_list* ns = new NFA_state_list;
ns->push_back(n->FirstState());
auto entry = states.find(*digest);
if ( entry == states.end() ) if ( ns->length() > 0 ) {
{ NFA_state_list* state_set = epsilon_closure(ns);
++misses; StateSetToDFA_State(state_set, start_state, ec);
return nullptr; }
} else {
++hits; start_state = nullptr; // Jam
delete ns;
digest->clear(); }
}
return entry->second;
} DFA_Machine::~DFA_Machine() {
delete dfa_state_cache;
DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) Unref(nfa);
{ }
states.emplace(std::move(digest), state);
return state; void DFA_Machine::Describe(ODesc* d) const { d->Add("DFA machine"); }
}
void DFA_Machine::Dump(FILE* f) {
void DFA_State_Cache::GetStats(Stats* s) start_state->Dump(f, this);
{ start_state->ClearMarks();
s->dfa_states = 0; }
s->nfa_states = 0;
s->computed = 0; bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec) {
s->uncomputed = 0; DigestStr digest;
s->mem = 0; d = dfa_state_cache->Lookup(*state_set, &digest);
s->hits = hits;
s->misses = misses; if ( d )
return false;
for ( const auto& state : states )
{ AcceptingSet* accept = new AcceptingSet;
DFA_State* e = state.second;
++s->dfa_states; for ( int i = 0; i < state_set->length(); ++i ) {
s->nfa_states += e->NFAStateNum(); int acc = (*state_set)[i]->Accept();
e->Stats(&s->computed, &s->uncomputed);
s->mem += util::pad_size(e->Size()) + padded_sizeof(*e); if ( acc != NO_ACCEPT )
} accept->insert(acc);
} }
DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) if ( accept->empty() ) {
{ delete accept;
state_count = 0; accept = nullptr;
}
nfa = n;
Ref(n); DFA_State* ds = new DFA_State(state_count++, ec, state_set, accept);
d = dfa_state_cache->Insert(ds, std::move(digest));
ec = arg_ec;
return true;
dfa_state_cache = new DFA_State_Cache(); }
NFA_state_list* ns = new NFA_state_list; int DFA_Machine::Rep(int sym) {
ns->push_back(n->FirstState()); for ( int i = 0; i < NUM_SYM; ++i )
if ( ec->SymEquivClass(i) == sym )
if ( ns->length() > 0 ) return i;
{
NFA_state_list* state_set = epsilon_closure(ns); return -1;
StateSetToDFA_State(state_set, start_state, ec); }
}
else } // namespace zeek::detail
{
start_state = nullptr; // Jam
delete ns;
}
}
DFA_Machine::~DFA_Machine()
{
delete dfa_state_cache;
Unref(nfa);
}
void DFA_Machine::Describe(ODesc* d) const
{
d->Add("DFA machine");
}
void DFA_Machine::Dump(FILE* f)
{
start_state->Dump(f, this);
start_state->ClearMarks();
}
bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d,
const EquivClass* ec)
{
DigestStr digest;
d = dfa_state_cache->Lookup(*state_set, &digest);
if ( d )
return false;
AcceptingSet* accept = new AcceptingSet;
for ( int i = 0; i < state_set->length(); ++i )
{
int acc = (*state_set)[i]->Accept();
if ( acc != NO_ACCEPT )
accept->insert(acc);
}
if ( accept->empty() )
{
delete accept;
accept = nullptr;
}
DFA_State* ds = new DFA_State(state_count++, ec, state_set, accept);
d = dfa_state_cache->Insert(ds, std::move(digest));
return true;
}
int DFA_Machine::Rep(int sym)
{
for ( int i = 0; i < NUM_SYM; ++i )
if ( ec->SymEquivClass(i) == sym )
return i;
return -1;
}
} // namespace zeek::detail

173
src/DFA.h
View file

@ -11,8 +11,7 @@
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/RE.h" // for typedef AcceptingSet #include "zeek/RE.h" // for typedef AcceptingSet
namespace zeek::detail namespace zeek::detail {
{
class DFA_State; class DFA_State;
class DFA_Machine; class DFA_Machine;
@ -22,132 +21,126 @@ class DFA_Machine;
#define DFA_UNCOMPUTED_STATE -2 #define DFA_UNCOMPUTED_STATE -2
#define DFA_UNCOMPUTED_STATE_PTR ((DFA_State*)DFA_UNCOMPUTED_STATE) #define DFA_UNCOMPUTED_STATE_PTR ((DFA_State*)DFA_UNCOMPUTED_STATE)
class DFA_State : public Obj class DFA_State : public Obj {
{
public: public:
DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, AcceptingSet* accept);
AcceptingSet* accept); ~DFA_State() override;
~DFA_State() override;
int StateNum() const { return state_num; } int StateNum() const { return state_num; }
int NFAStateNum() const { return nfa_states->length(); } int NFAStateNum() const { return nfa_states->length(); }
void AddXtion(int sym, DFA_State* next_state); void AddXtion(int sym, DFA_State* next_state);
inline DFA_State* Xtion(int sym, DFA_Machine* machine); inline DFA_State* Xtion(int sym, DFA_Machine* machine);
const AcceptingSet* Accept() const { return accept; } const AcceptingSet* Accept() const { return accept; }
void SymPartition(const EquivClass* ec); void SymPartition(const EquivClass* ec);
// ec_sym is an equivalence class, not a character. // ec_sym is an equivalence class, not a character.
NFA_state_list* SymFollowSet(int ec_sym, const EquivClass* ec); NFA_state_list* SymFollowSet(int ec_sym, const EquivClass* ec);
void SetMark(DFA_State* m) { mark = m; } void SetMark(DFA_State* m) { mark = m; }
DFA_State* Mark() const { return mark; } DFA_State* Mark() const { return mark; }
void ClearMarks(); void ClearMarks();
// Returns the equivalence classes of ec's corresponding to this state. // Returns the equivalence classes of ec's corresponding to this state.
const EquivClass* MetaECs() const { return meta_ec; } const EquivClass* MetaECs() const { return meta_ec; }
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
void Dump(FILE* f, DFA_Machine* m); void Dump(FILE* f, DFA_Machine* m);
void Stats(unsigned int* computed, unsigned int* uncomputed); void Stats(unsigned int* computed, unsigned int* uncomputed);
unsigned int Size(); unsigned int Size();
protected: protected:
friend class DFA_State_Cache; friend class DFA_State_Cache;
DFA_State* ComputeXtion(int sym, DFA_Machine* machine); DFA_State* ComputeXtion(int sym, DFA_Machine* machine);
void AppendIfNew(int sym, int_list* sym_list); void AppendIfNew(int sym, int_list* sym_list);
int state_num; int state_num;
int num_sym; int num_sym;
DFA_State** xtions; DFA_State** xtions;
AcceptingSet* accept; AcceptingSet* accept;
NFA_state_list* nfa_states; NFA_state_list* nfa_states;
EquivClass* meta_ec; // which ec's make same transition EquivClass* meta_ec; // which ec's make same transition
DFA_State* mark; DFA_State* mark;
}; };
using DigestStr = std::basic_string<u_char>; using DigestStr = std::basic_string<u_char>;
class DFA_State_Cache class DFA_State_Cache {
{
public: public:
DFA_State_Cache(); DFA_State_Cache();
~DFA_State_Cache(); ~DFA_State_Cache();
// If the caller stores the handle, it has to call Ref() on it. // If the caller stores the handle, it has to call Ref() on it.
DFA_State* Lookup(const NFA_state_list& nfa_states, DigestStr* digest); DFA_State* Lookup(const NFA_state_list& nfa_states, DigestStr* digest);
// Takes ownership of state; digest is the one returned by Lookup(). // Takes ownership of state; digest is the one returned by Lookup().
DFA_State* Insert(DFA_State* state, DigestStr digest); DFA_State* Insert(DFA_State* state, DigestStr digest);
int NumEntries() const { return states.size(); } int NumEntries() const { return states.size(); }
struct Stats struct Stats {
{ // Sum of all NFA states
// Sum of all NFA states unsigned int nfa_states;
unsigned int nfa_states; unsigned int dfa_states;
unsigned int dfa_states; unsigned int computed;
unsigned int computed; unsigned int uncomputed;
unsigned int uncomputed; unsigned int mem;
unsigned int mem; unsigned int hits;
unsigned int hits; unsigned int misses;
unsigned int misses; };
};
void GetStats(Stats* s); void GetStats(Stats* s);
private: private:
int hits; // Statistics int hits; // Statistics
int misses; int misses;
// Hash indexed by NFA states (MD5s of them, actually). // Hash indexed by NFA states (MD5s of them, actually).
std::map<DigestStr, DFA_State*> states; std::map<DigestStr, DFA_State*> states;
}; };
class DFA_Machine : public Obj class DFA_Machine : public Obj {
{
public: public:
DFA_Machine(NFA_Machine* n, EquivClass* ec); DFA_Machine(NFA_Machine* n, EquivClass* ec);
~DFA_Machine() override; ~DFA_Machine() override;
DFA_State* StartState() const { return start_state; } DFA_State* StartState() const { return start_state; }
int NumStates() const { return dfa_state_cache->NumEntries(); } int NumStates() const { return dfa_state_cache->NumEntries(); }
DFA_State_Cache* Cache() { return dfa_state_cache; } DFA_State_Cache* Cache() { return dfa_state_cache; }
int Rep(int sym); int Rep(int sym);
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
void Dump(FILE* f); void Dump(FILE* f);
protected: protected:
friend class DFA_State; // for DFA_State::ComputeXtion friend class DFA_State; // for DFA_State::ComputeXtion
friend class DFA_State_Cache; friend class DFA_State_Cache;
int state_count; int state_count;
// The state list has to be sorted according to IDs. // The state list has to be sorted according to IDs.
bool StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec); bool StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec);
const EquivClass* EC() const { return ec; } const EquivClass* EC() const { return ec; }
EquivClass* ec; // equivalence classes corresponding to NFAs EquivClass* ec; // equivalence classes corresponding to NFAs
DFA_State* start_state; DFA_State* start_state;
DFA_State_Cache* dfa_state_cache; DFA_State_Cache* dfa_state_cache;
NFA_Machine* nfa; NFA_Machine* nfa;
}; };
inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) {
{ if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR )
if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR ) return ComputeXtion(sym, machine);
return ComputeXtion(sym, machine); else
else return xtions[sym];
return xtions[sym]; }
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,428 +6,399 @@
#include "zeek/DNS_Mgr.h" #include "zeek/DNS_Mgr.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) {
{ Init(h);
Init(h); req_host = host;
req_host = host; req_ttl = ttl;
req_ttl = ttl; req_type = type;
req_type = type;
if ( names.empty() ) if ( names.empty() )
names.push_back(std::move(host)); names.push_back(std::move(host));
} }
DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) {
{ Init(h);
Init(h); req_addr = addr;
req_addr = addr; req_ttl = ttl;
req_ttl = ttl; req_type = T_PTR;
req_type = T_PTR; }
}
DNS_Mapping::DNS_Mapping(FILE* f) DNS_Mapping::DNS_Mapping(FILE* f) {
{ Clear();
Clear(); init_failed = true;
init_failed = true;
req_ttl = 0; req_ttl = 0;
creation_time = 0; creation_time = 0;
char buf[512]; char buf[512];
if ( ! fgets(buf, sizeof(buf), f) ) if ( ! fgets(buf, sizeof(buf), f) ) {
{ no_mapping = true;
no_mapping = true; return;
return; }
}
char req_buf[512 + 1], name_buf[512 + 1]; char req_buf[512 + 1], name_buf[512 + 1];
int is_req_host; int is_req_host;
int failed_local; int failed_local;
int num_addrs; int num_addrs;
if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, &failed_local,
&failed_local, name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) {
{ no_mapping = true;
no_mapping = true; return;
return; }
}
failed = static_cast<bool>(failed_local); failed = static_cast<bool>(failed_local);
if ( is_req_host ) if ( is_req_host )
req_host = req_buf; req_host = req_buf;
else else
req_addr = IPAddr(req_buf); req_addr = IPAddr(req_buf);
names.emplace_back(name_buf); names.emplace_back(name_buf);
for ( int i = 0; i < num_addrs; ++i ) for ( int i = 0; i < num_addrs; ++i ) {
{ if ( ! fgets(buf, sizeof(buf), f) )
if ( ! fgets(buf, sizeof(buf), f) ) return;
return;
char* newline = strchr(buf, '\n'); char* newline = strchr(buf, '\n');
if ( newline ) if ( newline )
*newline = '\0'; *newline = '\0';
addrs.emplace_back(buf); addrs.emplace_back(buf);
} }
init_failed = false; init_failed = false;
} }
ListValPtr DNS_Mapping::Addrs() ListValPtr DNS_Mapping::Addrs() {
{ if ( failed )
if ( failed ) return nullptr;
return nullptr;
if ( ! addrs_val ) if ( ! addrs_val ) {
{ addrs_val = make_intrusive<ListVal>(TYPE_ADDR);
addrs_val = make_intrusive<ListVal>(TYPE_ADDR);
for ( const auto& addr : addrs ) for ( const auto& addr : addrs )
addrs_val->Append(make_intrusive<AddrVal>(addr)); addrs_val->Append(make_intrusive<AddrVal>(addr));
} }
return addrs_val; return addrs_val;
} }
TableValPtr DNS_Mapping::AddrsSet() TableValPtr DNS_Mapping::AddrsSet() {
{ auto l = Addrs();
auto l = Addrs();
if ( ! l || l->Length() == 0 ) if ( ! l || l->Length() == 0 )
return DNS_Mgr::empty_addr_set(); return DNS_Mgr::empty_addr_set();
return l->ToSetVal(); return l->ToSetVal();
} }
StringValPtr DNS_Mapping::Host() StringValPtr DNS_Mapping::Host() {
{ if ( failed || names.empty() )
if ( failed || names.empty() ) return nullptr;
return nullptr;
if ( ! host_val ) if ( ! host_val )
host_val = make_intrusive<StringVal>(names[0]); host_val = make_intrusive<StringVal>(names[0]);
return host_val; return host_val;
} }
void DNS_Mapping::Init(struct hostent* h) void DNS_Mapping::Init(struct hostent* h) {
{ no_mapping = false;
no_mapping = false; init_failed = false;
init_failed = false; creation_time = util::current_time();
creation_time = util::current_time(); host_val = nullptr;
host_val = nullptr; addrs_val = nullptr;
addrs_val = nullptr;
if ( ! h ) if ( ! h ) {
{ Clear();
Clear(); return;
return; }
}
if ( h->h_name ) if ( h->h_name )
// for now, just use the official name // for now, just use the official name
// TODO: this could easily be expanded to include all of the aliases as well // TODO: this could easily be expanded to include all of the aliases as well
names.emplace_back(h->h_name); names.emplace_back(h->h_name);
if ( h->h_addr_list ) if ( h->h_addr_list ) {
{ for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) {
for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) if ( h->h_addrtype == AF_INET )
{ addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network);
if ( h->h_addrtype == AF_INET ) else if ( h->h_addrtype == AF_INET6 )
addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network); addrs.emplace_back(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network);
else if ( h->h_addrtype == AF_INET6 ) }
addrs.emplace_back(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network); }
}
}
failed = false; failed = false;
} }
void DNS_Mapping::Clear() void DNS_Mapping::Clear() {
{ names.clear();
names.clear(); host_val = nullptr;
host_val = nullptr; addrs.clear();
addrs.clear(); addrs_val = nullptr;
addrs_val = nullptr; no_mapping = false;
no_mapping = false; req_type = 0;
req_type = 0; failed = true;
failed = true; }
}
void DNS_Mapping::Save(FILE* f) const void DNS_Mapping::Save(FILE* f) const {
{ fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(),
fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed,
req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl);
names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl);
for ( const auto& addr : addrs ) for ( const auto& addr : addrs )
fprintf(f, "%s\n", addr.AsString().c_str()); fprintf(f, "%s\n", addr.AsString().c_str());
} }
void DNS_Mapping::Merge(const DNS_MappingPtr& other) void DNS_Mapping::Merge(const DNS_MappingPtr& other) {
{ std::copy(other->names.begin(), other->names.end(), std::back_inserter(names));
std::copy(other->names.begin(), other->names.end(), std::back_inserter(names)); std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs));
std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs)); }
}
// This value needs to be incremented if something changes in the data stored by Save(). This // This value needs to be incremented if something changes in the data stored by Save(). This
// allows us to change the structure of the cache without breaking something in DNS_Mgr. // allows us to change the structure of the cache without breaking something in DNS_Mgr.
constexpr int FILE_VERSION = 1; constexpr int FILE_VERSION = 1;
void DNS_Mapping::InitializeCache(FILE* f) void DNS_Mapping::InitializeCache(FILE* f) { fprintf(f, "%d\n", FILE_VERSION); }
{
fprintf(f, "%d\n", FILE_VERSION);
}
bool DNS_Mapping::ValidateCacheVersion(FILE* f) bool DNS_Mapping::ValidateCacheVersion(FILE* f) {
{ char buf[512];
char buf[512]; if ( ! fgets(buf, sizeof(buf), f) )
if ( ! fgets(buf, sizeof(buf), f) ) return false;
return false;
int version; int version;
if ( sscanf(buf, "%d", &version) != 1 ) if ( sscanf(buf, "%d", &version) != 1 ) {
{ reporter->Warning("Existing DNS cache did not have correct version, ignoring");
reporter->Warning("Existing DNS cache did not have correct version, ignoring"); return false;
return false; }
}
return FILE_VERSION == version; return FILE_VERSION == version;
} }
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
TEST_CASE("dns_mapping init null hostent") TEST_CASE("dns_mapping init null hostent") {
{ DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A);
DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A);
CHECK(! mapping.Valid()); CHECK(! mapping.Valid());
CHECK(mapping.Addrs() == nullptr); CHECK(mapping.Addrs() == nullptr);
CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set())); CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set()));
CHECK(mapping.Host() == nullptr); CHECK(mapping.Host() == nullptr);
} }
TEST_CASE("dns_mapping init host") TEST_CASE("dns_mapping init host") {
{ IPAddr addr("1.2.3.4");
IPAddr addr("1.2.3.4"); in4_addr in4;
in4_addr in4; addr.CopyIPv4(&in4);
addr.CopyIPv4(&in4);
struct hostent he; struct hostent he;
he.h_name = util::copy_string("testing.home"); he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL; he.h_aliases = NULL;
he.h_addrtype = AF_INET; he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr); he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4, NULL}; std::vector<in_addr*> addrs = {&in4, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data()); he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping("testing.home", &he, 123, T_A); DNS_Mapping mapping("testing.home", &he, 123, T_A);
CHECK(mapping.Valid()); CHECK(mapping.Valid());
CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified); CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified);
CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0); CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0);
CHECK(mapping.ReqStr() == "testing.home"); CHECK(mapping.ReqStr() == "testing.home");
auto lva = mapping.Addrs(); auto lva = mapping.Addrs();
REQUIRE(lva != nullptr); REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1); CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal(); auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr); REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4"); CHECK(lvae->Get().AsString() == "1.2.3.4");
auto tvas = mapping.AddrsSet(); auto tvas = mapping.AddrsSet();
REQUIRE(tvas != nullptr); REQUIRE(tvas != nullptr);
CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set()));
auto svh = mapping.Host(); auto svh = mapping.Host();
REQUIRE(svh != nullptr); REQUIRE(svh != nullptr);
CHECK(svh->ToStdString() == "testing.home"); CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping init addr") TEST_CASE("dns_mapping init addr") {
{ IPAddr addr("1.2.3.4");
IPAddr addr("1.2.3.4"); in4_addr in4;
in4_addr in4; addr.CopyIPv4(&in4);
addr.CopyIPv4(&in4);
struct hostent he; struct hostent he;
he.h_name = util::copy_string("testing.home"); he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL; he.h_aliases = NULL;
he.h_addrtype = AF_INET; he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr); he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4, NULL}; std::vector<in_addr*> addrs = {&in4, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data()); he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping(addr, &he, 123); DNS_Mapping mapping(addr, &he, 123);
CHECK(mapping.Valid()); CHECK(mapping.Valid());
CHECK(mapping.ReqAddr() == addr); CHECK(mapping.ReqAddr() == addr);
CHECK(mapping.ReqHost() == nullptr); CHECK(mapping.ReqHost() == nullptr);
CHECK(mapping.ReqStr() == "1.2.3.4"); CHECK(mapping.ReqStr() == "1.2.3.4");
auto lva = mapping.Addrs(); auto lva = mapping.Addrs();
REQUIRE(lva != nullptr); REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1); CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal(); auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr); REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4"); CHECK(lvae->Get().AsString() == "1.2.3.4");
auto tvas = mapping.AddrsSet(); auto tvas = mapping.AddrsSet();
REQUIRE(tvas != nullptr); REQUIRE(tvas != nullptr);
CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set()));
auto svh = mapping.Host(); auto svh = mapping.Host();
REQUIRE(svh != nullptr); REQUIRE(svh != nullptr);
CHECK(svh->ToStdString() == "testing.home"); CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping save reload") TEST_CASE("dns_mapping save reload") {
{ // TODO: this test uses fmemopen and mkdtemp, both of which aren't available on
// TODO: this test uses fmemopen and mkdtemp, both of which aren't available on // Windows. We'll have to figure out another way to do this test there.
// Windows. We'll have to figure out another way to do this test there.
#ifndef _MSC_VER #ifndef _MSC_VER
IPAddr addr("1.2.3.4"); IPAddr addr("1.2.3.4");
in4_addr in4; in4_addr in4;
addr.CopyIPv4(&in4); addr.CopyIPv4(&in4);
struct hostent he; struct hostent he;
he.h_name = util::copy_string("testing.home"); he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL; he.h_aliases = NULL;
he.h_addrtype = AF_INET; he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr); he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4, NULL}; std::vector<in_addr*> addrs = {&in4, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data()); he.h_addr_list = reinterpret_cast<char**>(addrs.data());
// Create a temporary file in memory and fseek to the end of it so we're at // Create a temporary file in memory and fseek to the end of it so we're at
// EOF for the next bit. // EOF for the next bit.
char buffer[4096]; char buffer[4096];
memset(buffer, 0, 4096); memset(buffer, 0, 4096);
FILE* tmpfile = fmemopen(buffer, 4096, "r+"); FILE* tmpfile = fmemopen(buffer, 4096, "r+");
if ( fseek(tmpfile, 0, SEEK_END) < 0 ) if ( fseek(tmpfile, 0, SEEK_END) < 0 )
reporter->Error("DNS_Mapping: seek failed"); reporter->Error("DNS_Mapping: seek failed");
// Try loading from the file at EOF. This should cause a mapping failure. // Try loading from the file at EOF. This should cause a mapping failure.
DNS_Mapping mapping(tmpfile); DNS_Mapping mapping(tmpfile);
CHECK(mapping.NoMapping()); CHECK(mapping.NoMapping());
rewind(tmpfile); rewind(tmpfile);
// Try reading from the empty file. This should cause an init failure. // Try reading from the empty file. This should cause an init failure.
DNS_Mapping mapping2(tmpfile); DNS_Mapping mapping2(tmpfile);
CHECK(mapping2.InitFailed()); CHECK(mapping2.InitFailed());
rewind(tmpfile); rewind(tmpfile);
// Save a valid mapping into the file and rewind to the start. // Save a valid mapping into the file and rewind to the start.
DNS_Mapping mapping3(addr, &he, 123); DNS_Mapping mapping3(addr, &he, 123);
mapping3.Save(tmpfile); mapping3.Save(tmpfile);
rewind(tmpfile); rewind(tmpfile);
// Test loading the mapping back out of the file // Test loading the mapping back out of the file
DNS_Mapping mapping4(tmpfile); DNS_Mapping mapping4(tmpfile);
fclose(tmpfile); fclose(tmpfile);
CHECK(mapping4.Valid()); CHECK(mapping4.Valid());
CHECK(mapping4.ReqAddr() == addr); CHECK(mapping4.ReqAddr() == addr);
CHECK(mapping4.ReqHost() == nullptr); CHECK(mapping4.ReqHost() == nullptr);
CHECK(mapping4.ReqStr() == "1.2.3.4"); CHECK(mapping4.ReqStr() == "1.2.3.4");
auto lva = mapping4.Addrs(); auto lva = mapping4.Addrs();
REQUIRE(lva != nullptr); REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1); CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal(); auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr); REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4"); CHECK(lvae->Get().AsString() == "1.2.3.4");
auto tvas = mapping4.AddrsSet(); auto tvas = mapping4.AddrsSet();
REQUIRE(tvas != nullptr); REQUIRE(tvas != nullptr);
CHECK(tvas != DNS_Mgr::empty_addr_set()); CHECK(tvas != DNS_Mgr::empty_addr_set());
auto svh = mapping4.Host(); auto svh = mapping4.Host();
REQUIRE(svh != nullptr); REQUIRE(svh != nullptr);
CHECK(svh->ToStdString() == "testing.home"); CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name; delete[] he.h_name;
#endif #endif
} }
TEST_CASE("dns_mapping multiple addresses") TEST_CASE("dns_mapping multiple addresses") {
{ IPAddr addr("1.2.3.4");
IPAddr addr("1.2.3.4"); in4_addr in4_1;
in4_addr in4_1; addr.CopyIPv4(&in4_1);
addr.CopyIPv4(&in4_1);
IPAddr addr2("5.6.7.8"); IPAddr addr2("5.6.7.8");
in4_addr in4_2; in4_addr in4_2;
addr2.CopyIPv4(&in4_2); addr2.CopyIPv4(&in4_2);
struct hostent he; struct hostent he;
he.h_name = util::copy_string("testing.home"); he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL; he.h_aliases = NULL;
he.h_addrtype = AF_INET; he.h_addrtype = AF_INET;
he.h_length = sizeof(in_addr); he.h_length = sizeof(in_addr);
std::vector<in_addr*> addrs = {&in4_1, &in4_2, NULL}; std::vector<in_addr*> addrs = {&in4_1, &in4_2, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data()); he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping("testing.home", &he, 123, T_A); DNS_Mapping mapping("testing.home", &he, 123, T_A);
CHECK(mapping.Valid()); CHECK(mapping.Valid());
auto lva = mapping.Addrs(); auto lva = mapping.Addrs();
REQUIRE(lva != nullptr); REQUIRE(lva != nullptr);
CHECK(lva->Length() == 2); CHECK(lva->Length() == 2);
auto lvae = lva->Idx(0)->AsAddrVal(); auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr); REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "1.2.3.4"); CHECK(lvae->Get().AsString() == "1.2.3.4");
lvae = lva->Idx(1)->AsAddrVal(); lvae = lva->Idx(1)->AsAddrVal();
REQUIRE(lvae != nullptr); REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "5.6.7.8"); CHECK(lvae->Get().AsString() == "5.6.7.8");
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping ipv6") TEST_CASE("dns_mapping ipv6") {
{ IPAddr addr("64:ff9b:1::");
IPAddr addr("64:ff9b:1::"); in6_addr in6;
in6_addr in6; addr.CopyIPv6(&in6);
addr.CopyIPv6(&in6);
struct hostent he; struct hostent he;
he.h_name = util::copy_string("testing.home"); he.h_name = util::copy_string("testing.home");
he.h_aliases = NULL; he.h_aliases = NULL;
he.h_addrtype = AF_INET6; he.h_addrtype = AF_INET6;
he.h_length = sizeof(in6_addr); he.h_length = sizeof(in6_addr);
std::vector<in6_addr*> addrs = {&in6, NULL}; std::vector<in6_addr*> addrs = {&in6, NULL};
he.h_addr_list = reinterpret_cast<char**>(addrs.data()); he.h_addr_list = reinterpret_cast<char**>(addrs.data());
DNS_Mapping mapping(addr, &he, 123); DNS_Mapping mapping(addr, &he, 123);
CHECK(mapping.Valid()); CHECK(mapping.Valid());
CHECK(mapping.ReqAddr() == addr); CHECK(mapping.ReqAddr() == addr);
CHECK(mapping.ReqHost() == nullptr); CHECK(mapping.ReqHost() == nullptr);
CHECK(mapping.ReqStr() == "64:ff9b:1::"); CHECK(mapping.ReqStr() == "64:ff9b:1::");
auto lva = mapping.Addrs(); auto lva = mapping.Addrs();
REQUIRE(lva != nullptr); REQUIRE(lva != nullptr);
CHECK(lva->Length() == 1); CHECK(lva->Length() == 1);
auto lvae = lva->Idx(0)->AsAddrVal(); auto lvae = lva->Idx(0)->AsAddrVal();
REQUIRE(lvae != nullptr); REQUIRE(lvae != nullptr);
CHECK(lvae->Get().AsString() == "64:ff9b:1::"); CHECK(lvae->Get().AsString() == "64:ff9b:1::");
delete[] he.h_name; delete[] he.h_name;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -8,73 +8,71 @@
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
class DNS_Mapping; class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>; using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Mapping class DNS_Mapping {
{
public: public:
DNS_Mapping() = delete; DNS_Mapping() = delete;
DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type); DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type);
DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl);
DNS_Mapping(FILE* f); DNS_Mapping(FILE* f);
bool NoMapping() const { return no_mapping; } bool NoMapping() const { return no_mapping; }
bool InitFailed() const { return init_failed; } bool InitFailed() const { return init_failed; }
~DNS_Mapping() = default; ~DNS_Mapping() = default;
// Returns nil if this was an address request. // Returns nil if this was an address request.
// TODO: fix this an uses of this to just return the empty string // TODO: fix this an uses of this to just return the empty string
const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); } const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); }
const IPAddr& ReqAddr() const { return req_addr; } const IPAddr& ReqAddr() const { return req_addr; }
std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; } std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; }
int ReqType() const { return req_type; } int ReqType() const { return req_type; }
ListValPtr Addrs(); ListValPtr Addrs();
TableValPtr AddrsSet(); // addresses returned as a set TableValPtr AddrsSet(); // addresses returned as a set
StringValPtr Host(); StringValPtr Host();
double CreationTime() const { return creation_time; } double CreationTime() const { return creation_time; }
uint32_t TTL() const { return req_ttl; } uint32_t TTL() const { return req_ttl; }
void Save(FILE* f) const; void Save(FILE* f) const;
bool Failed() const { return failed; } bool Failed() const { return failed; }
bool Valid() const { return ! failed; } bool Valid() const { return ! failed; }
bool Expired() const { return util::current_time() > (creation_time + req_ttl); } bool Expired() const { return util::current_time() > (creation_time + req_ttl); }
void Merge(const DNS_MappingPtr& other); void Merge(const DNS_MappingPtr& other);
static void InitializeCache(FILE* f); static void InitializeCache(FILE* f);
static bool ValidateCacheVersion(FILE* f); static bool ValidateCacheVersion(FILE* f);
protected: protected:
friend class DNS_Mgr; friend class DNS_Mgr;
void Init(struct hostent* h); void Init(struct hostent* h);
void Clear(); void Clear();
std::string req_host; std::string req_host;
IPAddr req_addr; IPAddr req_addr;
uint32_t req_ttl = 0; uint32_t req_ttl = 0;
int req_type = 0; int req_type = 0;
// This class supports multiple names per address, but we only store one of them. // This class supports multiple names per address, but we only store one of them.
std::vector<std::string> names; std::vector<std::string> names;
StringValPtr host_val; StringValPtr host_val;
std::vector<IPAddr> addrs; std::vector<IPAddr> addrs;
ListValPtr addrs_val; ListValPtr addrs_val;
double creation_time = 0.0; double creation_time = 0.0;
bool no_mapping = false; // when initializing from a file, immediately hit EOF bool no_mapping = false; // when initializing from a file, immediately hit EOF
bool init_failed = false; bool init_failed = false;
bool failed = false; bool failed = false;
}; };
} // namespace zeek::detail } // namespace zeek::detail

File diff suppressed because it is too large Load diff

View file

@ -27,326 +27,314 @@ typedef struct ares_channeldata* ares_channel;
#define T_TXT 16 #define T_TXT 16
#endif #endif
namespace zeek namespace zeek {
{
class Val; class Val;
class ListVal; class ListVal;
class TableVal; class TableVal;
class StringVal; class StringVal;
template <class T> class IntrusivePtr; template<class T>
class IntrusivePtr;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using ListValPtr = IntrusivePtr<ListVal>; using ListValPtr = IntrusivePtr<ListVal>;
using TableValPtr = IntrusivePtr<TableVal>; using TableValPtr = IntrusivePtr<TableVal>;
using StringValPtr = IntrusivePtr<StringVal>; using StringValPtr = IntrusivePtr<StringVal>;
} // namespace zeek } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class DNS_Mapping; class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>; using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Request; class DNS_Request;
enum DNS_MgrMode enum DNS_MgrMode {
{ DNS_PRIME, // used to prime the cache
DNS_PRIME, // used to prime the cache DNS_FORCE, // internal error if cache miss
DNS_FORCE, // internal error if cache miss DNS_DEFAULT, // lookup names as they're requested
DNS_DEFAULT, // lookup names as they're requested DNS_FAKE, // don't look up names, just return dummy results
DNS_FAKE, // don't look up names, just return dummy results };
};
class DNS_Mgr : public iosource::IOSource class DNS_Mgr : public iosource::IOSource {
{
public: public:
/** /**
* Base class for callback handling for asynchronous lookups. * Base class for callback handling for asynchronous lookups.
*/ */
class LookupCallback class LookupCallback {
{ public:
public: virtual ~LookupCallback() = default;
virtual ~LookupCallback() = default;
/** /**
* Called when an address lookup finishes. * Called when an address lookup finishes.
* *
* @param name The resulting name from the lookup. * @param name The resulting name from the lookup.
*/ */
virtual void Resolved(const std::string& name){}; virtual void Resolved(const std::string& name){};
/** /**
* Called when a name lookup finishes. * Called when a name lookup finishes.
* *
* @param addrs A table of the resulting addresses from the lookup. * @param addrs A table of the resulting addresses from the lookup.
*/ */
virtual void Resolved(TableValPtr addrs){}; virtual void Resolved(TableValPtr addrs){};
/** /**
* Generic callback method for all request types. * Generic callback method for all request types.
* *
* @param val A Val containing the data from the query. * @param val A Val containing the data from the query.
*/ */
virtual void Resolved(ValPtr data, int request_type) { } virtual void Resolved(ValPtr data, int request_type) {}
/** /**
* Called when a timeout request occurs. * Called when a timeout request occurs.
*/ */
virtual void Timeout() = 0; virtual void Timeout() = 0;
}; };
explicit DNS_Mgr(DNS_MgrMode mode); explicit DNS_Mgr(DNS_MgrMode mode);
~DNS_Mgr() override; ~DNS_Mgr() override;
/** /**
* Finalizes the source when it's being closed. * Finalizes the source when it's being closed.
*/ */
void Done() override; void Done() override;
/** /**
* Finalizes the manager initialization. This should be called only after all * Finalizes the manager initialization. This should be called only after all
* of the scripts have been parsed at startup. * of the scripts have been parsed at startup.
*/ */
void InitPostScript(); void InitPostScript();
/** /**
* Attempts to process one more round of requests and then flushes the * Attempts to process one more round of requests and then flushes the
* mapping caches. * mapping caches.
*/ */
void Flush(); void Flush();
/** /**
* Looks up the address(es) of a given host and returns a set of addresses. * Looks up the address(es) of a given host and returns a set of addresses.
* This is a shorthand method for doing A/AAAA requests. This is a * This is a shorthand method for doing A/AAAA requests. This is a
* synchronous request and will block until the request completes or times * synchronous request and will block until the request completes or times
* out. * out.
* *
* @param host The hostname to lookup an address for. * @param host The hostname to lookup an address for.
* @return A set of addresses for the host. * @return A set of addresses for the host.
*/ */
TableValPtr LookupHost(const std::string& host); TableValPtr LookupHost(const std::string& host);
/** /**
* Looks up the hostname of a given address. This is a shorthand method for * Looks up the hostname of a given address. This is a shorthand method for
* doing PTR requests. This is a synchronous request and will block until * doing PTR requests. This is a synchronous request and will block until
* the request completes or times out. * the request completes or times out.
* *
* @param host The addr to lookup a hostname for. * @param host The addr to lookup a hostname for.
* @return The hostname for the address. * @return The hostname for the address.
*/ */
StringValPtr LookupAddr(const IPAddr& addr); StringValPtr LookupAddr(const IPAddr& addr);
/** /**
* Performs a generic request to the DNS server. This is a synchronous * Performs a generic request to the DNS server. This is a synchronous
* request and will block until the request completes or times out. * request and will block until the request completes or times out.
* *
* @param name The name or address to make a request for. If this is an * @param name The name or address to make a request for. If this is an
* address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa).
* Note that calling LookupAddr for PTR requests does this conversion * Note that calling LookupAddr for PTR requests does this conversion
* automatically. * automatically.
* @param request_type The type of request to make. This should be one of * @param request_type The type of request to make. This should be one of
* the type values defined in arpa/nameser.h or ares_nameser.h. * the type values defined in arpa/nameser.h or ares_nameser.h.
* @return The requested data. * @return The requested data.
*/ */
ValPtr Lookup(const std::string& name, int request_type); ValPtr Lookup(const std::string& name, int request_type);
/** /**
* Looks up the address(es) of a given host. This is a shorthand method * Looks up the address(es) of a given host. This is a shorthand method
* for doing A/AAAA requests. This is an asynchronous request. The * for doing A/AAAA requests. This is an asynchronous request. The
* response will be handled via the provided callback object. * response will be handled via the provided callback object.
* *
* @param host The hostname to lookup an address for. * @param host The hostname to lookup an address for.
* @param callback A callback object for handling the response. * @param callback A callback object for handling the response.
*/ */
void LookupHost(const std::string& host, LookupCallback* callback); void LookupHost(const std::string& host, LookupCallback* callback);
/** /**
* Looks up the hostname of a given address. This is a shorthand method for * Looks up the hostname of a given address. This is a shorthand method for
* doing PTR requests. This is an asynchronous request. The response will * doing PTR requests. This is an asynchronous request. The response will
* be handled via the provided callback object. * be handled via the provided callback object.
* *
* @param host The addr to lookup a hostname for. * @param host The addr to lookup a hostname for.
* @param callback A callback object for handling the response. * @param callback A callback object for handling the response.
*/ */
void LookupAddr(const IPAddr& addr, LookupCallback* callback); void LookupAddr(const IPAddr& addr, LookupCallback* callback);
/** /**
* Performs a generic request to the DNS server. This is an asynchronous * Performs a generic request to the DNS server. This is an asynchronous
* request. The response will be handled via the provided callback * request. The response will be handled via the provided callback
* object. * object.
* *
* @param name The name or address to make a request for. If this is an * @param name The name or address to make a request for. If this is an
* address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa).
* Note that calling LookupAddr for PTR requests does this conversion * Note that calling LookupAddr for PTR requests does this conversion
* automatically. * automatically.
* @param request_type The type of request to make. This should be one of * @param request_type The type of request to make. This should be one of
* the type values defined in arpa/nameser.h or ares_nameser.h. * the type values defined in arpa/nameser.h or ares_nameser.h.
* @param callback A callback object for handling the response. * @param callback A callback object for handling the response.
*/ */
void Lookup(const std::string& name, int request_type, LookupCallback* callback); void Lookup(const std::string& name, int request_type, LookupCallback* callback);
/** /**
* Sets the directory where to store DNS data when Save() is called. * Sets the directory where to store DNS data when Save() is called.
*/ */
void SetDir(const std::string& arg_dir) { dir = arg_dir; } void SetDir(const std::string& arg_dir) { dir = arg_dir; }
/** /**
* Waits for responses to become available or a timeout to occur, * Waits for responses to become available or a timeout to occur,
* and handles any responses. * and handles any responses.
*/ */
void Resolve(); void Resolve();
/** /**
* Saves the current name and address caches to disk. * Saves the current name and address caches to disk.
*/ */
bool Save(); bool Save();
struct Stats struct Stats {
{ unsigned long requests; // These count only async requests.
unsigned long requests; // These count only async requests. unsigned long successful;
unsigned long successful; unsigned long failed;
unsigned long failed; unsigned long pending;
unsigned long pending; unsigned long cached_hosts;
unsigned long cached_hosts; unsigned long cached_addresses;
unsigned long cached_addresses; unsigned long cached_texts;
unsigned long cached_texts; unsigned long cached_total;
unsigned long cached_total; };
};
/** /**
* Returns the current statistics for the DNS_Manager. * Returns the current statistics for the DNS_Manager.
* *
* @param stats A pointer to a stats object to return the data in. * @param stats A pointer to a stats object to return the data in.
*/ */
void GetStats(Stats* stats); void GetStats(Stats* stats);
/** /**
* Adds a result from a request to the caches. This is public so that the * Adds a result from a request to the caches. This is public so that the
* callback methods can call it from outside of the DNS_Mgr class. * callback methods can call it from outside of the DNS_Mgr class.
* *
* @param dr The request associated with the result. * @param dr The request associated with the result.
* @param h A hostent structure containing the actual result data. * @param h A hostent structure containing the actual result data.
* @param ttl A ttl value contained in the response from the server. * @param ttl A ttl value contained in the response from the server.
* @param merge A flag for whether these results should be merged into * @param merge A flag for whether these results should be merged into
* an existing mapping. If false, AddResult will attempt to replace the * an existing mapping. If false, AddResult will attempt to replace the
* existing mapping with the new data and delete the old mapping. * existing mapping with the new data and delete the old mapping.
*/ */
void AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool merge = false); void AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool merge = false);
/** /**
* Returns an empty set of addresses, used in various error cases and during * Returns an empty set of addresses, used in various error cases and during
* cache priming. * cache priming.
*/ */
static TableValPtr empty_addr_set(); static TableValPtr empty_addr_set();
/** /**
* Returns the full path to the file used to store the DNS cache. * Returns the full path to the file used to store the DNS cache.
*/ */
std::string CacheFile() const { return cache_name; } std::string CacheFile() const { return cache_name; }
/** /**
* Used by the c-ares socket call back to register/unregister a socket file descriptor. * Used by the c-ares socket call back to register/unregister a socket file descriptor.
*/ */
void RegisterSocket(int fd, bool read, bool write); void RegisterSocket(int fd, bool read, bool write);
ares_channel& GetChannel() { return channel; } ares_channel& GetChannel() { return channel; }
protected: protected:
friend class LookupCallback; friend class LookupCallback;
friend class DNS_Request; friend class DNS_Request;
StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, bool check_failed = false);
bool check_failed = false); TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, bool check_failed = false);
TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, StringValPtr LookupOtherInCache(const std::string& name, int request_type, bool cleanup_expired = false);
bool check_failed = false);
StringValPtr LookupOtherInCache(const std::string& name, int request_type,
bool cleanup_expired = false);
// Finish the request if we have a result. If not, time it out if // Finish the request if we have a result. If not, time it out if
// requested. // requested.
void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout); void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout);
void CheckAsyncHostRequest(const std::string& host, bool timeout); void CheckAsyncHostRequest(const std::string& host, bool timeout);
void CheckAsyncOtherRequest(const std::string& host, bool timeout, int request_type); void CheckAsyncOtherRequest(const std::string& host, bool timeout, int request_type);
void Event(EventHandlerPtr e, const DNS_MappingPtr& dm); void Event(EventHandlerPtr e, const DNS_MappingPtr& dm);
void Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2); void Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2);
void Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm); void Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm);
ValPtr BuildMappingVal(const DNS_MappingPtr& dm); ValPtr BuildMappingVal(const DNS_MappingPtr& dm);
void CompareMappings(const DNS_MappingPtr& prev_dm, const DNS_MappingPtr& new_dm); void CompareMappings(const DNS_MappingPtr& prev_dm, const DNS_MappingPtr& new_dm);
ListValPtr AddrListDelta(ListValPtr al1, ListValPtr al2); ListValPtr AddrListDelta(ListValPtr al1, ListValPtr al2);
using MappingKey = std::variant<IPAddr, std::pair<int, std::string>>; using MappingKey = std::variant<IPAddr, std::pair<int, std::string>>;
using MappingMap = std::map<MappingKey, DNS_MappingPtr>; using MappingMap = std::map<MappingKey, DNS_MappingPtr>;
void LoadCache(const std::string& path); void LoadCache(const std::string& path);
void Save(FILE* f, const MappingMap& m); void Save(FILE* f, const MappingMap& m);
// Issue as many queued async requests as slots are available. // Issue as many queued async requests as slots are available.
void IssueAsyncRequests(); void IssueAsyncRequests();
// IOSource interface. // IOSource interface.
void Process() override; void Process() override;
void ProcessFd(int fd, int flags) override; void ProcessFd(int fd, int flags) override;
void InitSource() override; void InitSource() override;
const char* Tag() override { return "DNS_Mgr"; } const char* Tag() override { return "DNS_Mgr"; }
double GetNextTimeout() override; double GetNextTimeout() override;
DNS_MgrMode mode; DNS_MgrMode mode;
MappingMap all_mappings; MappingMap all_mappings;
std::string cache_name; std::string cache_name;
std::string dir; // directory in which cache_name resides std::string dir; // directory in which cache_name resides
bool did_init = false; bool did_init = false;
int asyncs_pending = 0; int asyncs_pending = 0;
RecordTypePtr dm_rec; RecordTypePtr dm_rec;
ares_channel channel{}; ares_channel channel{};
using CallbackList = std::list<LookupCallback*>; using CallbackList = std::list<LookupCallback*>;
struct AsyncRequest struct AsyncRequest {
{ double time = 0.0;
double time = 0.0; IPAddr addr;
IPAddr addr; std::string host;
std::string host; CallbackList callbacks;
CallbackList callbacks; int type = 0;
int type = 0; bool processed = false;
bool processed = false;
AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) {}
{ AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) {}
}
AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) { }
void Resolved(const std::string& name); void Resolved(const std::string& name);
void Resolved(TableValPtr addrs); void Resolved(TableValPtr addrs);
void Timeout(); void Timeout();
}; };
struct AsyncRequestCompare struct AsyncRequestCompare {
{ bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; }
bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; } };
};
using AsyncRequestMap = std::map<MappingKey, AsyncRequest*>; using AsyncRequestMap = std::map<MappingKey, AsyncRequest*>;
AsyncRequestMap asyncs; AsyncRequestMap asyncs;
using QueuedList = std::list<AsyncRequest*>; using QueuedList = std::list<AsyncRequest*>;
QueuedList asyncs_queued; QueuedList asyncs_queued;
unsigned long num_requests = 0; unsigned long num_requests = 0;
unsigned long successful = 0; unsigned long successful = 0;
unsigned long failed = 0; unsigned long failed = 0;
std::set<int> socket_fds; std::set<int> socket_fds;
std::set<int> write_socket_fds; std::set<int> write_socket_fds;
bool shutting_down = false; bool shutting_down = false;
}; };
extern DNS_Mgr* dns_mgr; extern DNS_Mgr* dns_mgr;
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -18,357 +18,307 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/module_util.h" #include "zeek/module_util.h"
namespace zeek::detail namespace zeek::detail {
{
// BreakpointTimer used for time-based breakpoints // BreakpointTimer used for time-based breakpoints
class BreakpointTimer final : public Timer class BreakpointTimer final : public Timer {
{
public: public:
BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) { bp = arg_bp; }
{
bp = arg_bp;
}
void Dispatch(double t, bool is_expire) override; void Dispatch(double t, bool is_expire) override;
protected: protected:
DbgBreakpoint* bp; DbgBreakpoint* bp;
}; };
void BreakpointTimer::Dispatch(double t, bool is_expire) void BreakpointTimer::Dispatch(double t, bool is_expire) {
{ if ( is_expire )
if ( is_expire ) return;
return;
bp->ShouldBreak(t);
bp->ShouldBreak(t); }
}
DbgBreakpoint::DbgBreakpoint() {
DbgBreakpoint::DbgBreakpoint() kind = BP_STMT;
{
kind = BP_STMT; enabled = temporary = false;
BPID = -1;
enabled = temporary = false;
BPID = -1; at_stmt = nullptr;
at_time = -1.0;
at_stmt = nullptr;
at_time = -1.0; repeat_count = hit_count = 0;
repeat_count = hit_count = 0; description[0] = 0;
source_filename = nullptr;
description[0] = 0; source_line = 0;
source_filename = nullptr; }
source_line = 0;
} DbgBreakpoint::~DbgBreakpoint() {
SetEnable(false); // clean up any active state
DbgBreakpoint::~DbgBreakpoint() RemoveFromGlobalMap();
{ }
SetEnable(false); // clean up any active state
RemoveFromGlobalMap(); bool DbgBreakpoint::SetEnable(bool do_enable) {
} bool old_value = enabled;
enabled = do_enable;
bool DbgBreakpoint::SetEnable(bool do_enable)
{ // Update statement counts.
bool old_value = enabled; if ( do_enable && ! old_value )
enabled = do_enable; AddToStmt();
// Update statement counts. else if ( ! do_enable && old_value )
if ( do_enable && ! old_value ) RemoveFromStmt();
AddToStmt();
return old_value;
else if ( ! do_enable && old_value ) }
RemoveFromStmt();
void DbgBreakpoint::AddToGlobalMap() {
return old_value; // Make sure it's not there already.
} RemoveFromGlobalMap();
void DbgBreakpoint::AddToGlobalMap() g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this));
{ }
// Make sure it's not there already.
RemoveFromGlobalMap(); void DbgBreakpoint::RemoveFromGlobalMap() {
std::pair<BPMapType::iterator, BPMapType::iterator> p;
g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this)); p = g_debugger_state.breakpoint_map.equal_range(at_stmt);
}
for ( BPMapType::iterator i = p.first; i != p.second; ) {
void DbgBreakpoint::RemoveFromGlobalMap() if ( i->second == this ) {
{ BPMapType::iterator next = i;
std::pair<BPMapType::iterator, BPMapType::iterator> p; ++next;
p = g_debugger_state.breakpoint_map.equal_range(at_stmt); g_debugger_state.breakpoint_map.erase(i);
i = next;
for ( BPMapType::iterator i = p.first; i != p.second; ) }
{ else
if ( i->second == this ) ++i;
{ }
BPMapType::iterator next = i; }
++next;
g_debugger_state.breakpoint_map.erase(i); void DbgBreakpoint::AddToStmt() {
i = next; if ( at_stmt )
} at_stmt->IncrBPCount();
else }
++i;
} void DbgBreakpoint::RemoveFromStmt() {
} if ( at_stmt )
at_stmt->DecrBPCount();
void DbgBreakpoint::AddToStmt() }
{
if ( at_stmt ) bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str) {
at_stmt->IncrBPCount(); if ( plr.type == PLR_UNKNOWN ) {
} debug_msg("Breakpoint specifier invalid or operation canceled.\n");
return false;
void DbgBreakpoint::RemoveFromStmt() }
{
if ( at_stmt ) if ( plr.type == PLR_FILE_AND_LINE ) {
at_stmt->DecrBPCount(); kind = BP_LINE;
} source_filename = plr.filename;
source_line = plr.line;
bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str)
{ if ( ! plr.stmt ) {
if ( plr.type == PLR_UNKNOWN ) debug_msg("No statement at that line.\n");
{ return false;
debug_msg("Breakpoint specifier invalid or operation canceled.\n"); }
return false;
} at_stmt = plr.stmt;
snprintf(description, sizeof(description), "%s:%d", source_filename, source_line);
if ( plr.type == PLR_FILE_AND_LINE )
{ debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
kind = BP_LINE; }
source_filename = plr.filename;
source_line = plr.line; else if ( plr.type == PLR_FUNCTION ) {
std::string loc_s(loc_str);
if ( ! plr.stmt ) kind = BP_FUNC;
{ function_name = make_full_var_name(current_module.c_str(), loc_s.c_str());
debug_msg("No statement at that line.\n"); at_stmt = plr.stmt;
return false; const Location* loc = at_stmt->GetLocationInfo();
} snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), loc->filename, loc->last_line);
at_stmt = plr.stmt; debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
snprintf(description, sizeof(description), "%s:%d", source_filename, source_line); }
debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); SetEnable(true);
} AddToGlobalMap();
return true;
else if ( plr.type == PLR_FUNCTION ) }
{
std::string loc_s(loc_str); bool DbgBreakpoint::SetLocation(Stmt* stmt) {
kind = BP_FUNC; if ( ! stmt )
function_name = make_full_var_name(current_module.c_str(), loc_s.c_str()); return false;
at_stmt = plr.stmt;
const Location* loc = at_stmt->GetLocationInfo(); kind = BP_STMT;
snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), at_stmt = stmt;
loc->filename, loc->last_line);
SetEnable(true);
debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); AddToGlobalMap();
}
const Location* loc = stmt->GetLocationInfo();
SetEnable(true); snprintf(description, sizeof(description), "%s:%d", loc->filename, loc->last_line);
AddToGlobalMap();
return true; debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
}
return true;
bool DbgBreakpoint::SetLocation(Stmt* stmt) }
{
if ( ! stmt ) bool DbgBreakpoint::SetLocation(double t) {
return false; debug_msg("SetLocation(time) has not been debugged.");
return false;
kind = BP_STMT;
at_stmt = stmt; kind = BP_TIME;
at_time = t;
SetEnable(true);
AddToGlobalMap(); timer_mgr->Add(new BreakpointTimer(this, t));
const Location* loc = stmt->GetLocationInfo(); debug_msg("Time-based breakpoints not yet supported.\n");
snprintf(description, sizeof(description), "%s:%d", loc->filename, loc->last_line); return false;
}
debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
bool DbgBreakpoint::Reset() {
return true; ParseLocationRec plr;
}
switch ( kind ) {
bool DbgBreakpoint::SetLocation(double t) case BP_TIME: debug_msg("Time-based breakpoints not yet supported.\n"); break;
{
debug_msg("SetLocation(time) has not been debugged."); case BP_FUNC:
return false; case BP_STMT:
case BP_LINE:
kind = BP_TIME; plr.type = PLR_FUNCTION;
at_time = t; //### How to deal with wildcards?
//### perhaps save user choices?--tough...
timer_mgr->Add(new BreakpointTimer(this, t)); break;
}
debug_msg("Time-based breakpoints not yet supported.\n");
return false; reporter->InternalError("DbgBreakpoint::Reset function incomplete.");
}
// Cannot be reached.
bool DbgBreakpoint::Reset() return false;
{ }
ParseLocationRec plr;
bool DbgBreakpoint::SetCondition(const std::string& new_condition) {
switch ( kind ) condition = new_condition;
{ return true;
case BP_TIME: }
debug_msg("Time-based breakpoints not yet supported.\n");
break; bool DbgBreakpoint::SetRepeatCount(int count) {
repeat_count = count;
case BP_FUNC: return true;
case BP_STMT: }
case BP_LINE:
plr.type = PLR_FUNCTION; BreakCode DbgBreakpoint::HasHit() {
//### How to deal with wildcards? if ( temporary ) {
//### perhaps save user choices?--tough... SetEnable(false);
break; return BC_HIT_AND_DELETE;
} }
reporter->InternalError("DbgBreakpoint::Reset function incomplete."); if ( condition.size() ) {
// TODO: ### evaluate using debugger frame too
// Cannot be reached. auto yes = dbg_eval_expr(condition.c_str());
return false;
} if ( ! yes ) {
debug_msg("Breakpoint condition '%s' invalid, removing condition.\n", condition.c_str());
bool DbgBreakpoint::SetCondition(const std::string& new_condition) SetCondition("");
{ PrintHitMsg();
condition = new_condition; return BC_HIT;
return true; }
}
if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) {
bool DbgBreakpoint::SetRepeatCount(int count) PrintHitMsg();
{ debug_msg("Breakpoint condition should return an integral type");
repeat_count = count; return BC_HIT_AND_DELETE;
return true; }
}
yes->CoerceToInt();
BreakCode DbgBreakpoint::HasHit() if ( yes->IsZero() ) {
{ return BC_NO_HIT;
if ( temporary ) }
{ }
SetEnable(false);
return BC_HIT_AND_DELETE; int repcount = GetRepeatCount();
} if ( repcount ) {
if ( ++hit_count == repcount ) {
if ( condition.size() ) hit_count = 0;
{ PrintHitMsg();
// TODO: ### evaluate using debugger frame too return BC_HIT;
auto yes = dbg_eval_expr(condition.c_str()); }
if ( ! yes ) return BC_NO_HIT;
{ }
debug_msg("Breakpoint condition '%s' invalid, removing condition.\n",
condition.c_str()); PrintHitMsg();
SetCondition(""); return BC_HIT;
PrintHitMsg(); }
return BC_HIT;
} BreakCode DbgBreakpoint::ShouldBreak(Stmt* s) {
if ( ! IsEnabled() )
if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) return BC_NO_HIT;
{
PrintHitMsg(); switch ( kind ) {
debug_msg("Breakpoint condition should return an integral type"); case BP_STMT:
return BC_HIT_AND_DELETE; case BP_FUNC:
} if ( at_stmt != s )
return BC_NO_HIT;
yes->CoerceToInt(); break;
if ( yes->IsZero() )
{ case BP_LINE:
return BC_NO_HIT; assert(s->GetLocationInfo()->first_line <= source_line && s->GetLocationInfo()->last_line >= source_line);
} break;
}
case BP_TIME: assert(false);
int repcount = GetRepeatCount();
if ( repcount ) default: reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak");
{ }
if ( ++hit_count == repcount )
{ // If we got here, that means that the breakpoint could hit,
hit_count = 0; // except potentially if it has a special condition or a repeat count.
PrintHitMsg();
return BC_HIT; BreakCode code = HasHit();
} if ( code )
g_debugger_state.BreakBeforeNextStmt(true);
return BC_NO_HIT;
} return code;
}
PrintHitMsg();
return BC_HIT; BreakCode DbgBreakpoint::ShouldBreak(double t) {
} if ( kind != BP_TIME )
reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint");
BreakCode DbgBreakpoint::ShouldBreak(Stmt* s)
{ if ( t < at_time )
if ( ! IsEnabled() ) return BC_NO_HIT;
return BC_NO_HIT;
if ( ! IsEnabled() )
switch ( kind ) return BC_NO_HIT;
{
case BP_STMT: BreakCode code = HasHit();
case BP_FUNC: if ( code )
if ( at_stmt != s ) g_debugger_state.BreakBeforeNextStmt(true);
return BC_NO_HIT;
break; return code;
}
case BP_LINE:
assert(s->GetLocationInfo()->first_line <= source_line && void DbgBreakpoint::PrintHitMsg() {
s->GetLocationInfo()->last_line >= source_line); switch ( kind ) {
break; case BP_STMT:
case BP_FUNC:
case BP_TIME: case BP_LINE: {
assert(false); ODesc d;
Frame* f = g_frame_stack.back();
default: const ScriptFunc* func = f->GetFunction();
reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak");
} if ( func )
func->DescribeDebug(&d, f->GetFuncArgs());
// If we got here, that means that the breakpoint could hit,
// except potentially if it has a special condition or a repeat count. const Location* loc = at_stmt->GetLocationInfo();
BreakCode code = HasHit(); debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename, loc->first_line);
if ( code ) }
g_debugger_state.BreakBeforeNextStmt(true); return;
return code; case BP_TIME: assert(false);
}
default: reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n");
BreakCode DbgBreakpoint::ShouldBreak(double t) }
{ }
if ( kind != BP_TIME )
reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint"); } // namespace zeek::detail
if ( t < at_time )
return BC_NO_HIT;
if ( ! IsEnabled() )
return BC_NO_HIT;
BreakCode code = HasHit();
if ( code )
g_debugger_state.BreakBeforeNextStmt(true);
return code;
}
void DbgBreakpoint::PrintHitMsg()
{
switch ( kind )
{
case BP_STMT:
case BP_FUNC:
case BP_LINE:
{
ODesc d;
Frame* f = g_frame_stack.back();
const ScriptFunc* func = f->GetFunction();
if ( func )
func->DescribeDebug(&d, f->GetFuncArgs());
const Location* loc = at_stmt->GetLocationInfo();
debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename,
loc->first_line);
}
return;
case BP_TIME:
assert(false);
default:
reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n");
}
}
} // namespace zeek::detail

View file

@ -6,95 +6,82 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
class Stmt; class Stmt;
class ParseLocationRec; class ParseLocationRec;
enum BreakCode enum BreakCode { BC_NO_HIT, BC_HIT, BC_HIT_AND_DELETE };
{ class DbgBreakpoint {
BC_NO_HIT, enum Kind { BP_STMT = 0, BP_FUNC, BP_LINE, BP_TIME };
BC_HIT,
BC_HIT_AND_DELETE
};
class DbgBreakpoint
{
enum Kind
{
BP_STMT = 0,
BP_FUNC,
BP_LINE,
BP_TIME
};
public: public:
DbgBreakpoint(); DbgBreakpoint();
~DbgBreakpoint(); ~DbgBreakpoint();
int GetID() const { return BPID; } int GetID() const { return BPID; }
void SetID(int newID) { BPID = newID; } void SetID(int newID) { BPID = newID; }
// True if breakpoint could be set; false otherwise // True if breakpoint could be set; false otherwise
bool SetLocation(ParseLocationRec plr, std::string_view loc_str); bool SetLocation(ParseLocationRec plr, std::string_view loc_str);
bool SetLocation(Stmt* stmt); bool SetLocation(Stmt* stmt);
bool SetLocation(double time); bool SetLocation(double time);
bool Reset(); // cancel and re-apply bpt when restarting execution bool Reset(); // cancel and re-apply bpt when restarting execution
// Temporary = disable (remove?) the breakpoint right after it's hit. // Temporary = disable (remove?) the breakpoint right after it's hit.
bool IsTemporary() const { return temporary; } bool IsTemporary() const { return temporary; }
void SetTemporary(bool is_temporary) { temporary = is_temporary; } void SetTemporary(bool is_temporary) { temporary = is_temporary; }
// Feed it a Stmt* or a time and see if this breakpoint should // Feed it a Stmt* or a time and see if this breakpoint should
// hit. bcHitAndDelete means that it has hit, and should now be // hit. bcHitAndDelete means that it has hit, and should now be
// deleted entirely. // deleted entirely.
// //
// NOTE: If it returns a hit, the DbgBreakpoint object will take // NOTE: If it returns a hit, the DbgBreakpoint object will take
// appropriate action (e.g., resetting counters). // appropriate action (e.g., resetting counters).
BreakCode ShouldBreak(Stmt* s); BreakCode ShouldBreak(Stmt* s);
BreakCode ShouldBreak(double t); BreakCode ShouldBreak(double t);
const std::string& GetCondition() const { return condition; } const std::string& GetCondition() const { return condition; }
bool SetCondition(const std::string& new_condition); bool SetCondition(const std::string& new_condition);
int GetRepeatCount() const { return repeat_count; } int GetRepeatCount() const { return repeat_count; }
bool SetRepeatCount(int count); // implements function of ignore command in gdb bool SetRepeatCount(int count); // implements function of ignore command in gdb
bool IsEnabled() const { return enabled; } bool IsEnabled() const { return enabled; }
bool SetEnable(bool do_enable); bool SetEnable(bool do_enable);
// e.g. "FooBar() at foo.c:23" // e.g. "FooBar() at foo.c:23"
const char* Description() const { return description; } const char* Description() const { return description; }
protected: protected:
void AddToGlobalMap(); void AddToGlobalMap();
void RemoveFromGlobalMap(); void RemoveFromGlobalMap();
void AddToStmt(); void AddToStmt();
void RemoveFromStmt(); void RemoveFromStmt();
BreakCode HasHit(); // a breakpoint hit, update state, return proper code. BreakCode HasHit(); // a breakpoint hit, update state, return proper code.
void PrintHitMsg(); // display reason when the breakpoint hits void PrintHitMsg(); // display reason when the breakpoint hits
Kind kind; Kind kind;
int32_t BPID; int32_t BPID;
char description[512]; char description[512];
std::string function_name; // location std::string function_name; // location
const char* source_filename; const char* source_filename;
int32_t source_line; int32_t source_line;
bool enabled; // ### comment this and next bool enabled; // ### comment this and next
bool temporary; bool temporary;
Stmt* at_stmt; Stmt* at_stmt;
double at_time; // break when the virtual time is this double at_time; // break when the virtual time is this
// Support for conditional and N'th time breakpoints. // Support for conditional and N'th time breakpoints.
int32_t repeat_count; // if positive, break after this many hits int32_t repeat_count; // if positive, break after this many hits
int32_t hit_count; // how many times it's been hit (w/o breaking) int32_t hit_count; // how many times it's been hit (w/o breaking)
std::string condition; // condition to evaluate; nil for none std::string condition; // condition to evaluate; nil for none
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -2,30 +2,27 @@
#pragma once #pragma once
namespace zeek::detail namespace zeek::detail {
{
class Expr; class Expr;
// Automatic displays: display these at each stoppage. // Automatic displays: display these at each stoppage.
class DbgDisplay class DbgDisplay {
{
public: public:
DbgDisplay(Expr* expr_to_display); DbgDisplay(Expr* expr_to_display);
bool IsEnabled() { return enabled; } bool IsEnabled() { return enabled; }
bool SetEnable(bool do_enable) bool SetEnable(bool do_enable) {
{ bool old_value = enabled;
bool old_value = enabled; enabled = do_enable;
enabled = do_enable; return old_value;
return old_value; }
}
const Expr* Expression() const { return expression; } const Expr* Expression() const { return expression; }
protected: protected:
bool enabled; bool enabled;
Expr* expression; Expr* expression;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,18 +7,11 @@
#include "zeek/Debug.h" #include "zeek/Debug.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
// Support classes // Support classes
DbgWatch::DbgWatch(zeek::Obj* var_to_watch) DbgWatch::DbgWatch(zeek::Obj* var_to_watch) { reporter->InternalError("DbgWatch unimplemented"); }
{
reporter->InternalError("DbgWatch unimplemented");
}
DbgWatch::DbgWatch(Expr* expr_to_watch) DbgWatch::DbgWatch(Expr* expr_to_watch) { reporter->InternalError("DbgWatch unimplemented"); }
{
reporter->InternalError("DbgWatch unimplemented");
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,26 +4,23 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
class Obj; class Obj;
} }
namespace zeek::detail namespace zeek::detail {
{
class Expr; class Expr;
class DbgWatch class DbgWatch {
{
public: public:
explicit DbgWatch(Obj* var_to_watch); explicit DbgWatch(Obj* var_to_watch);
explicit DbgWatch(Expr* expr_to_watch); explicit DbgWatch(Expr* expr_to_watch);
~DbgWatch() = default; ~DbgWatch() = default;
protected: protected:
Obj* var; Obj* var;
Expr* expr; Expr* expr;
}; };
} // namespace zeek::detail } // namespace zeek::detail

File diff suppressed because it is too large Load diff

View file

@ -11,17 +11,16 @@
#include "zeek/StmtEnums.h" #include "zeek/StmtEnums.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
class Val; class Val;
template <class T> class IntrusivePtr; template<class T>
class IntrusivePtr;
using ValPtr = zeek::IntrusivePtr<Val>; using ValPtr = zeek::IntrusivePtr<Val>;
extern std::string current_module; extern std::string current_module;
namespace detail namespace detail {
{
class Frame; class Frame;
class Stmt; class Stmt;
@ -30,20 +29,14 @@ class DbgWatch;
class DbgDisplay; class DbgDisplay;
// This needs to be defined before we do the includes that come after it. // This needs to be defined before we do the includes that come after it.
enum ParseLocationRecType enum ParseLocationRecType { PLR_UNKNOWN, PLR_FILE_AND_LINE, PLR_FUNCTION };
{ class ParseLocationRec {
PLR_UNKNOWN,
PLR_FILE_AND_LINE,
PLR_FUNCTION
};
class ParseLocationRec
{
public: public:
ParseLocationRecType type; ParseLocationRecType type;
int32_t line; int32_t line;
Stmt* stmt; Stmt* stmt;
const char* filename; const char* filename;
}; };
class StmtLocMapping; class StmtLocMapping;
using Filemap = std::deque<StmtLocMapping*>; // mapping for a single file using Filemap = std::deque<StmtLocMapping*>; // mapping for a single file
@ -51,94 +44,89 @@ using Filemap = std::deque<StmtLocMapping*>; // mapping for a single file
using BPIDMapType = std::map<int, DbgBreakpoint*>; using BPIDMapType = std::map<int, DbgBreakpoint*>;
using BPMapType = std::multimap<const Stmt*, DbgBreakpoint*>; using BPMapType = std::multimap<const Stmt*, DbgBreakpoint*>;
class TraceState class TraceState {
{
public: public:
TraceState() TraceState() {
{ dbgtrace = false;
dbgtrace = false; trace_file = stderr;
trace_file = stderr; }
}
// Returns previous filename. // Returns previous filename.
FILE* SetTraceFile(const char* trace_filename); FILE* SetTraceFile(const char* trace_filename);
bool DoTrace() const { return dbgtrace; } bool DoTrace() const { return dbgtrace; }
void TraceOn(); void TraceOn();
void TraceOff(); void TraceOff();
int LogTrace(const char* fmt, ...) __attribute__((format(printf, 2, 3))); int LogTrace(const char* fmt, ...) __attribute__((format(printf, 2, 3)));
; ;
protected: protected:
bool dbgtrace; // print an execution trace bool dbgtrace; // print an execution trace
FILE* trace_file; FILE* trace_file;
}; };
extern TraceState g_trace_state; extern TraceState g_trace_state;
class DebuggerState class DebuggerState {
{
public: public:
DebuggerState(); DebuggerState();
~DebuggerState(); ~DebuggerState();
int NextBPID() { return next_bp_id++; } int NextBPID() { return next_bp_id++; }
int NextWatchID() { return next_watch_id++; } int NextWatchID() { return next_watch_id++; }
int NextDisplayID() { return next_display_id++; } int NextDisplayID() { return next_display_id++; }
bool BreakBeforeNextStmt() { return break_before_next_stmt; } bool BreakBeforeNextStmt() { return break_before_next_stmt; }
void BreakBeforeNextStmt(bool dobrk) { break_before_next_stmt = dobrk; } void BreakBeforeNextStmt(bool dobrk) { break_before_next_stmt = dobrk; }
bool BreakFromSignal() { return break_from_signal; } bool BreakFromSignal() { return break_from_signal; }
void BreakFromSignal(bool dobrk) { break_from_signal = dobrk; } void BreakFromSignal(bool dobrk) { break_from_signal = dobrk; }
// Temporary state: vanishes when execution resumes. // Temporary state: vanishes when execution resumes.
//### Umesh, why do these all need to be public? -- Vern //### Umesh, why do these all need to be public? -- Vern
// Which frame we're looking at; 0 = the innermost frame. // Which frame we're looking at; 0 = the innermost frame.
int curr_frame_idx; int curr_frame_idx;
bool already_did_list; // did we already do a 'list' command? bool already_did_list; // did we already do a 'list' command?
Location last_loc; // used by 'list'; the last location listed Location last_loc; // used by 'list'; the last location listed
BPIDMapType breakpoints; // BPID -> Breakpoint BPIDMapType breakpoints; // BPID -> Breakpoint
std::vector<DbgWatch*> watches; std::vector<DbgWatch*> watches;
std::vector<DbgDisplay*> displays; std::vector<DbgDisplay*> displays;
BPMapType breakpoint_map; // maps Stmt -> Breakpoints on it BPMapType breakpoint_map; // maps Stmt -> Breakpoints on it
protected: protected:
bool break_before_next_stmt; // trap into debugger (used for "step") bool break_before_next_stmt; // trap into debugger (used for "step")
bool break_from_signal; // was break caused by a signal? bool break_from_signal; // was break caused by a signal?
int next_bp_id, next_watch_id, next_display_id; int next_bp_id, next_watch_id, next_display_id;
private: private:
Frame* dbg_locals; // unused Frame* dbg_locals; // unused
}; };
// Source line -> statement mapping. // Source line -> statement mapping.
// (obj -> source line mapping available in object itself) // (obj -> source line mapping available in object itself)
class StmtLocMapping class StmtLocMapping {
{
public: public:
StmtLocMapping() { } StmtLocMapping() {}
StmtLocMapping(const Location* l, Stmt* s) StmtLocMapping(const Location* l, Stmt* s) {
{ loc = *l;
loc = *l; stmt = s;
stmt = s; }
}
bool StartsAfter(const StmtLocMapping* m2); bool StartsAfter(const StmtLocMapping* m2);
const Location& Loc() const { return loc; } const Location& Loc() const { return loc; }
Stmt* Statement() const { return stmt; } Stmt* Statement() const { return stmt; }
protected: protected:
Location loc; Location loc;
Stmt* stmt = nullptr; Stmt* stmt = nullptr;
}; };
extern bool g_policy_debug; // enable debugging facility extern bool g_policy_debug; // enable debugging facility
extern DebuggerState g_debugger_state; extern DebuggerState g_debugger_state;
@ -195,5 +183,5 @@ extern std::map<std::string, Filemap*> g_dbgfilemaps; // filename => filemap
// Perhaps add a code/priority argument to do selective output. // Perhaps add a code/priority argument to do selective output.
int debug_msg(const char* fmt, ...) __attribute__((format(printf, 1, 2))); int debug_msg(const char* fmt, ...) __attribute__((format(printf, 1, 2)));
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -11,39 +11,37 @@
// This file is generated during the build. // This file is generated during the build.
#include "DebugCmdConstants.h" #include "DebugCmdConstants.h"
namespace zeek::detail namespace zeek::detail {
{
class DebugCmdInfo class DebugCmdInfo {
{
public: public:
DebugCmdInfo(const DebugCmdInfo& info); DebugCmdInfo(const DebugCmdInfo& info);
DebugCmdInfo(DebugCmd cmd, const char* const* names, int num_names, bool resume_execution, DebugCmdInfo(DebugCmd cmd, const char* const* names, int num_names, bool resume_execution,
const char* const helpstring, bool repeatable); const char* const helpstring, bool repeatable);
DebugCmdInfo() : helpstring(nullptr) { } DebugCmdInfo() : helpstring(nullptr) {}
int Cmd() const { return cmd; } int Cmd() const { return cmd; }
int NumNames() const { return num_names; } int NumNames() const { return num_names; }
const std::vector<const char*>& Names() const { return names; } const std::vector<const char*>& Names() const { return names; }
bool ResumeExecution() const { return resume_execution; } bool ResumeExecution() const { return resume_execution; }
const char* Helpstring() const { return helpstring; } const char* Helpstring() const { return helpstring; }
bool Repeatable() const { return repeatable; } bool Repeatable() const { return repeatable; }
protected: protected:
DebugCmd cmd; DebugCmd cmd;
int32_t num_names; int32_t num_names;
std::vector<const char*> names; std::vector<const char*> names;
const char* const helpstring; const char* const helpstring;
// Whether executing this should restart execution of the script. // Whether executing this should restart execution of the script.
bool resume_execution; bool resume_execution;
// Does entering a blank line repeat this command? // Does entering a blank line repeat this command?
bool repeatable; bool repeatable;
}; };
using DebugCmdInfoQueue = std::deque<DebugCmdInfo*>; using DebugCmdInfoQueue = std::deque<DebugCmdInfo*>;
extern DebugCmdInfoQueue g_DebugCmdInfos; extern DebugCmdInfoQueue g_DebugCmdInfos;
@ -80,4 +78,4 @@ DbgCmdFn dbg_cmd_info;
DbgCmdFn dbg_cmd_list; DbgCmdFn dbg_cmd_list;
DbgCmdFn dbg_cmd_trace; DbgCmdFn dbg_cmd_trace;
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -11,202 +11,178 @@
zeek::detail::DebugLogger zeek::detail::debug_logger; zeek::detail::DebugLogger zeek::detail::debug_logger;
zeek::detail::DebugLogger& debug_logger = zeek::detail::debug_logger; zeek::detail::DebugLogger& debug_logger = zeek::detail::debug_logger;
namespace zeek::detail namespace zeek::detail {
{
// Same order here as in DebugStream. // Same order here as in DebugStream.
DebugLogger::Stream DebugLogger::streams[NUM_DBGS] = { DebugLogger::Stream DebugLogger::streams[NUM_DBGS] =
{"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, {{"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, {"notifiers", 0, false},
{"notifiers", 0, false}, {"main-loop", 0, false}, {"dpd", 0, false}, {"main-loop", 0, false}, {"dpd", 0, false}, {"packet_analysis", 0, false}, {"file_analysis", 0, false},
{"packet_analysis", 0, false}, {"file_analysis", 0, false}, {"tm", 0, false}, {"tm", 0, false}, {"logging", 0, false}, {"input", 0, false}, {"threading", 0, false},
{"logging", 0, false}, {"input", 0, false}, {"threading", 0, false}, {"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, {"broker", 0, false},
{"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, {"scripts", 0, false}, {"supervisor", 0, false}, {"hashkey", 0, false}, {"spicy", 0, false}};
{"broker", 0, false}, {"scripts", 0, false}, {"supervisor", 0, false},
{"hashkey", 0, false}, {"spicy", 0, false}};
DebugLogger::DebugLogger() DebugLogger::DebugLogger() {
{ verbose = false;
verbose = false; file = nullptr;
file = nullptr; }
}
DebugLogger::~DebugLogger() DebugLogger::~DebugLogger() {
{ if ( file && file != stderr )
if ( file && file != stderr ) fclose(file);
fclose(file); }
}
void DebugLogger::OpenDebugLog(const char* filename) void DebugLogger::OpenDebugLog(const char* filename) {
{ if ( filename ) {
if ( filename ) filename = util::detail::log_file_name(filename);
{
filename = util::detail::log_file_name(filename);
file = fopen(filename, "w"); file = fopen(filename, "w");
if ( ! file ) if ( ! file ) {
{ // The reporter may not be initialized here yet.
// The reporter may not be initialized here yet. if ( reporter )
if ( reporter ) reporter->FatalError("can't open '%s' for debugging output", filename);
reporter->FatalError("can't open '%s' for debugging output", filename); else {
else fprintf(stderr, "can't open '%s' for debugging output\n", filename);
{ exit(1);
fprintf(stderr, "can't open '%s' for debugging output\n", filename); }
exit(1); }
}
}
util::detail::setvbuf(file, NULL, _IOLBF, 0); util::detail::setvbuf(file, NULL, _IOLBF, 0);
} }
else else
file = stderr; file = stderr;
} }
void DebugLogger::ShowStreamsHelp() void DebugLogger::ShowStreamsHelp() {
{ fprintf(stderr, "\n");
fprintf(stderr, "\n"); fprintf(stderr, "Enable debug output into debug.log with -B <streams>.\n");
fprintf(stderr, "Enable debug output into debug.log with -B <streams>.\n"); fprintf(stderr, "<streams> is a comma-separated list of streams to enable.\n");
fprintf(stderr, "<streams> is a comma-separated list of streams to enable.\n"); fprintf(stderr, "\n");
fprintf(stderr, "\n"); fprintf(stderr, "Available streams:\n");
fprintf(stderr, "Available streams:\n");
for ( int i = 0; i < NUM_DBGS; ++i ) for ( int i = 0; i < NUM_DBGS; ++i )
fprintf(stderr, " %s\n", streams[i].prefix); fprintf(stderr, " %s\n", streams[i].prefix);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, " plugin-<plugin-name> (replace '::' in name with '-'; e.g., '-B " fprintf(stderr,
"plugin-Zeek-Netmap')\n"); " plugin-<plugin-name> (replace '::' in name with '-'; e.g., '-B "
fprintf(stderr, "\n"); "plugin-Zeek-Netmap')\n");
fprintf(stderr, "Pseudo streams\n"); fprintf(stderr, "\n");
fprintf(stderr, " verbose Increase verbosity.\n"); fprintf(stderr, "Pseudo streams\n");
fprintf(stderr, " all Enable all streams at maximum verbosity.\n"); fprintf(stderr, " verbose Increase verbosity.\n");
fprintf(stderr, "\n"); fprintf(stderr, " all Enable all streams at maximum verbosity.\n");
} fprintf(stderr, "\n");
}
void DebugLogger::EnableStreams(const char* s) void DebugLogger::EnableStreams(const char* s) {
{ char* brkt;
char* brkt; char* tmp = util::copy_string(s);
char* tmp = util::copy_string(s); char* tok = strtok(tmp, ",");
char* tok = strtok(tmp, ",");
while ( tok ) while ( tok ) {
{ if ( strcasecmp("all", tok) == 0 ) {
if ( strcasecmp("all", tok) == 0 ) for ( int i = 0; i < NUM_DBGS; ++i ) {
{ streams[i].enabled = true;
for ( int i = 0; i < NUM_DBGS; ++i ) enabled_streams.insert(streams[i].prefix);
{ }
streams[i].enabled = true;
enabled_streams.insert(streams[i].prefix);
}
verbose = true; verbose = true;
goto next; goto next;
} }
if ( strcasecmp("verbose", tok) == 0 ) if ( strcasecmp("verbose", tok) == 0 ) {
{ verbose = true;
verbose = true; goto next;
goto next; }
}
if ( strcasecmp("help", tok) == 0 ) if ( strcasecmp("help", tok) == 0 ) {
{ ShowStreamsHelp();
ShowStreamsHelp(); exit(0);
exit(0); }
}
if ( util::starts_with(tok, "plugin-") ) if ( util::starts_with(tok, "plugin-") ) {
{ // Cannot verify this at this time, plugins may not
// Cannot verify this at this time, plugins may not // have been loaded.
// have been loaded. enabled_streams.insert(tok);
enabled_streams.insert(tok); goto next;
goto next; }
}
int i; int i;
for ( i = 0; i < NUM_DBGS; ++i ) for ( i = 0; i < NUM_DBGS; ++i ) {
{ if ( strcasecmp(streams[i].prefix, tok) == 0 ) {
if ( strcasecmp(streams[i].prefix, tok) == 0 ) streams[i].enabled = true;
{ enabled_streams.insert(tok);
streams[i].enabled = true; goto next;
enabled_streams.insert(tok); }
goto next; }
}
}
reporter->FatalError("unknown debug stream '%s', try -B help.\n", tok); reporter->FatalError("unknown debug stream '%s', try -B help.\n", tok);
next: next:
tok = strtok(0, ","); tok = strtok(0, ",");
} }
delete[] tmp; delete[] tmp;
} }
bool DebugLogger::CheckStreams(const std::set<std::string>& plugin_names) bool DebugLogger::CheckStreams(const std::set<std::string>& plugin_names) {
{ bool ok = true;
bool ok = true;
std::set<std::string> available_plugin_streams; std::set<std::string> available_plugin_streams;
for ( const auto& p : plugin_names ) for ( const auto& p : plugin_names )
available_plugin_streams.insert(PluginStreamName(p)); available_plugin_streams.insert(PluginStreamName(p));
for ( const auto& stream : enabled_streams ) for ( const auto& stream : enabled_streams ) {
{ if ( ! util::starts_with(stream, "plugin-") )
if ( ! util::starts_with(stream, "plugin-") ) continue;
continue;
if ( available_plugin_streams.count(stream) == 0 ) if ( available_plugin_streams.count(stream) == 0 ) {
{ reporter->Error("No plugin debug stream '%s' found", stream.c_str());
reporter->Error("No plugin debug stream '%s' found", stream.c_str()); ok = false;
ok = false; }
} }
}
return ok; return ok;
} }
void DebugLogger::Log(DebugStream stream, const char* fmt, ...) void DebugLogger::Log(DebugStream stream, const char* fmt, ...) {
{ Stream* g = &streams[int(stream)];
Stream* g = &streams[int(stream)];
if ( ! g->enabled ) if ( ! g->enabled )
return; return;
fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), g->prefix);
g->prefix);
for ( int i = g->indent; i > 0; --i ) for ( int i = g->indent; i > 0; --i )
fputs(" ", file); fputs(" ", file);
va_list ap; va_list ap;
va_start(ap, fmt); va_start(ap, fmt);
vfprintf(file, fmt, ap); vfprintf(file, fmt, ap);
va_end(ap); va_end(ap);
fputc('\n', file); fputc('\n', file);
fflush(file); fflush(file);
} }
void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) {
{ std::string tok = PluginStreamName(plugin.Name());
std::string tok = PluginStreamName(plugin.Name());
if ( enabled_streams.find(tok) == enabled_streams.end() ) if ( enabled_streams.find(tok) == enabled_streams.end() )
return; return;
fprintf(file, "%17.06f/%17.06f [plugin %s] ", run_state::network_time, util::current_time(true), fprintf(file, "%17.06f/%17.06f [plugin %s] ", run_state::network_time, util::current_time(true),
plugin.Name().c_str()); plugin.Name().c_str());
va_list ap; va_list ap;
va_start(ap, fmt); va_start(ap, fmt);
vfprintf(file, fmt, ap); vfprintf(file, fmt, ap);
va_end(ap); va_end(ap);
fputc('\n', file); fputc('\n', file);
fflush(file); fflush(file);
} }
} // namespace zeek::detail } // namespace zeek::detail
#endif #endif

View file

@ -13,115 +13,106 @@
#include "zeek/util.h" #include "zeek/util.h"
#define DBG_LOG(stream, ...) \ #define DBG_LOG(stream, ...) \
if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \ if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__)
#define DBG_LOG_VERBOSE(stream, ...) \ #define DBG_LOG_VERBOSE(stream, ...) \
if ( ::zeek::detail::debug_logger.IsVerbose() && \ if ( ::zeek::detail::debug_logger.IsVerbose() && ::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.IsEnabled(stream) ) \ ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__)
::zeek::detail::debug_logger.Log(stream, __VA_ARGS__)
#define DBG_PUSH(stream) ::zeek::detail::debug_logger.PushIndent(stream) #define DBG_PUSH(stream) ::zeek::detail::debug_logger.PushIndent(stream)
#define DBG_POP(stream) ::zeek::detail::debug_logger.PopIndent(stream) #define DBG_POP(stream) ::zeek::detail::debug_logger.PopIndent(stream)
#define PLUGIN_DBG_LOG(plugin, ...) ::zeek::detail::debug_logger.Log(plugin, __VA_ARGS__) #define PLUGIN_DBG_LOG(plugin, ...) ::zeek::detail::debug_logger.Log(plugin, __VA_ARGS__)
namespace zeek namespace zeek {
{
namespace plugin namespace plugin {
{
class Plugin; class Plugin;
} }
// To add a new debugging stream, add a constant here as well as // To add a new debugging stream, add a constant here as well as
// an entry to DebugLogger::streams in DebugLogger.cc. // an entry to DebugLogger::streams in DebugLogger.cc.
enum DebugStream enum DebugStream {
{ DBG_SERIAL, // Serialization
DBG_SERIAL, // Serialization DBG_RULES, // Signature matching
DBG_RULES, // Signature matching DBG_STRING, // String code
DBG_STRING, // String code DBG_NOTIFIERS, // Notifiers
DBG_NOTIFIERS, // Notifiers DBG_MAINLOOP, // Main IOSource loop
DBG_MAINLOOP, // Main IOSource loop DBG_ANALYZER, // Analyzer framework
DBG_ANALYZER, // Analyzer framework DBG_PACKET_ANALYSIS, // Packet analysis
DBG_PACKET_ANALYSIS, // Packet analysis DBG_FILE_ANALYSIS, // File analysis
DBG_FILE_ANALYSIS, // File analysis DBG_TM, // Time-machine packet input via Broccoli
DBG_TM, // Time-machine packet input via Broccoli DBG_LOGGING, // Logging streams
DBG_LOGGING, // Logging streams DBG_INPUT, // Input streams
DBG_INPUT, // Input streams DBG_THREADING, // Threading system
DBG_THREADING, // Threading system DBG_PLUGINS, // Plugin system
DBG_PLUGINS, // Plugin system DBG_ZEEKYGEN, // Zeekygen
DBG_ZEEKYGEN, // Zeekygen DBG_PKTIO, // Packet sources and dumpers.
DBG_PKTIO, // Packet sources and dumpers. DBG_BROKER, // Broker communication
DBG_BROKER, // Broker communication DBG_SCRIPTS, // Script initialization
DBG_SCRIPTS, // Script initialization DBG_SUPERVISOR, // Process supervisor
DBG_SUPERVISOR, // Process supervisor DBG_HASHKEY, // HashKey buffers
DBG_HASHKEY, // HashKey buffers DBG_SPICY, // Spicy functionality
DBG_SPICY, // Spicy functionality
NUM_DBGS // Has to be last NUM_DBGS // Has to be last
}; };
namespace detail namespace detail {
{
class DebugLogger class DebugLogger {
{
public: public:
// Output goes to stderr per default. // Output goes to stderr per default.
DebugLogger(); DebugLogger();
~DebugLogger(); ~DebugLogger();
void OpenDebugLog(const char* filename = 0); void OpenDebugLog(const char* filename = 0);
void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4))); void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4)));
void Log(const plugin::Plugin& plugin, const char* fmt, ...) void Log(const plugin::Plugin& plugin, const char* fmt, ...) __attribute__((format(printf, 3, 4)));
__attribute__((format(printf, 3, 4)));
void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; } void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; }
void PopIndent(DebugStream stream) { --streams[int(stream)].indent; } void PopIndent(DebugStream stream) { --streams[int(stream)].indent; }
void EnableStream(DebugStream stream) { streams[int(stream)].enabled = true; } void EnableStream(DebugStream stream) { streams[int(stream)].enabled = true; }
void DisableStream(DebugStream stream) { streams[int(stream)].enabled = false; } void DisableStream(DebugStream stream) { streams[int(stream)].enabled = false; }
// Takes comma-separated list of stream prefixes. // Takes comma-separated list of stream prefixes.
void EnableStreams(const char* streams); void EnableStreams(const char* streams);
// Check the enabled streams for invalid ones. // Check the enabled streams for invalid ones.
bool CheckStreams(const std::set<std::string>& plugin_names); bool CheckStreams(const std::set<std::string>& plugin_names);
bool IsEnabled(DebugStream stream) const { return streams[int(stream)].enabled; } bool IsEnabled(DebugStream stream) const { return streams[int(stream)].enabled; }
void SetVerbose(bool arg_verbose) { verbose = arg_verbose; } void SetVerbose(bool arg_verbose) { verbose = arg_verbose; }
bool IsVerbose() const { return verbose; } bool IsVerbose() const { return verbose; }
void ShowStreamsHelp(); void ShowStreamsHelp();
private: private:
FILE* file; FILE* file;
bool verbose; bool verbose;
struct Stream struct Stream {
{ const char* prefix;
const char* prefix; int indent;
int indent; bool enabled;
bool enabled; };
};
std::set<std::string> enabled_streams; std::set<std::string> enabled_streams;
static Stream streams[NUM_DBGS]; static Stream streams[NUM_DBGS];
const std::string PluginStreamName(const std::string& plugin_name) const std::string PluginStreamName(const std::string& plugin_name) {
{ return "plugin-" + util::strreplace(plugin_name, "::", "-");
return "plugin-" + util::strreplace(plugin_name, "::", "-"); }
} };
};
extern DebugLogger debug_logger; extern DebugLogger debug_logger;
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek
#else #else
#define DBG_LOG(...) #define DBG_LOG(...)

View file

@ -17,428 +17,362 @@
#define DEFAULT_SIZE 128 #define DEFAULT_SIZE 128
#define SLOP 10 #define SLOP 10
namespace zeek namespace zeek {
{
ODesc::ODesc(DescType t, File* arg_f) {
ODesc::ODesc(DescType t, File* arg_f) type = t;
{ style = STANDARD_STYLE;
type = t; f = arg_f;
style = STANDARD_STYLE;
f = arg_f; if ( f == nullptr ) {
size = DEFAULT_SIZE;
if ( f == nullptr ) base = util::safe_malloc(size);
{ ((char*)base)[0] = '\0';
size = DEFAULT_SIZE; offset = 0;
base = util::safe_malloc(size); }
((char*)base)[0] = '\0'; else {
offset = 0; offset = size = 0;
} base = nullptr;
else }
{
offset = size = 0; indent_level = 0;
base = nullptr; is_short = false;
} want_quotes = false;
want_determinism = false;
indent_level = 0; do_flush = true;
is_short = false; include_stats = false;
want_quotes = false; indent_with_spaces = 0;
want_determinism = false; escape = false;
do_flush = true; utf8 = false;
include_stats = false; }
indent_with_spaces = 0;
escape = false; ODesc::~ODesc() {
utf8 = false; if ( f ) {
} if ( do_flush )
f->Flush();
ODesc::~ODesc() }
{ else if ( base )
if ( f ) free(base);
{ }
if ( do_flush )
f->Flush(); void ODesc::EnableEscaping() { escape = true; }
}
else if ( base ) void ODesc::EnableUTF8() { utf8 = true; }
free(base);
} void ODesc::PushIndent() {
++indent_level;
void ODesc::EnableEscaping() NL();
{ }
escape = true;
} void ODesc::PopIndent() {
if ( --indent_level < 0 )
void ODesc::EnableUTF8() reporter->InternalError("ODesc::PopIndent underflow");
{
utf8 = true; NL();
} }
void ODesc::PushIndent() void ODesc::PopIndentNoNL() {
{ if ( --indent_level < 0 )
++indent_level; reporter->InternalError("ODesc::PopIndent underflow");
NL(); }
}
void ODesc::Add(const char* s, int do_indent) {
void ODesc::PopIndent() unsigned int n = strlen(s);
{
if ( --indent_level < 0 ) if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' )
reporter->InternalError("ODesc::PopIndent underflow"); Indent();
NL(); if ( IsBinary() )
} AddBytes(s, n + 1);
else
void ODesc::PopIndentNoNL() AddBytes(s, n);
{ }
if ( --indent_level < 0 )
reporter->InternalError("ODesc::PopIndent underflow"); void ODesc::Add(int i) {
} if ( IsBinary() )
AddBytes(&i, sizeof(i));
void ODesc::Add(const char* s, int do_indent) else {
{ char tmp[256];
unsigned int n = strlen(s); modp_litoa10(i, tmp);
Add(tmp);
if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' ) }
Indent(); }
if ( IsBinary() ) void ODesc::Add(uint32_t u) {
AddBytes(s, n + 1); if ( IsBinary() )
else AddBytes(&u, sizeof(u));
AddBytes(s, n); else {
} char tmp[256];
modp_ulitoa10(u, tmp);
void ODesc::Add(int i) Add(tmp);
{ }
if ( IsBinary() ) }
AddBytes(&i, sizeof(i));
else void ODesc::Add(int64_t i) {
{ if ( IsBinary() )
char tmp[256]; AddBytes(&i, sizeof(i));
modp_litoa10(i, tmp); else {
Add(tmp); char tmp[256];
} modp_litoa10(i, tmp);
} Add(tmp);
}
void ODesc::Add(uint32_t u) }
{
if ( IsBinary() ) void ODesc::Add(uint64_t u) {
AddBytes(&u, sizeof(u)); if ( IsBinary() )
else AddBytes(&u, sizeof(u));
{ else {
char tmp[256]; char tmp[256];
modp_ulitoa10(u, tmp); modp_ulitoa10(u, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(int64_t i) void ODesc::Add(double d, bool no_exp) {
{ if ( IsBinary() )
if ( IsBinary() ) AddBytes(&d, sizeof(d));
AddBytes(&i, sizeof(i)); else {
else // Buffer needs enough chars to store max. possible "double" value
{ // of 1.79e308 without using scientific notation.
char tmp[256]; char tmp[350];
modp_litoa10(i, tmp);
Add(tmp); if ( no_exp )
} modp_dtoa3(d, tmp, sizeof(tmp), IsReadable() ? 6 : 8);
} else
modp_dtoa2(d, tmp, IsReadable() ? 6 : 8);
void ODesc::Add(uint64_t u)
{ Add(tmp);
if ( IsBinary() )
AddBytes(&u, sizeof(u)); auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool {
else auto v = a - b;
{ return v < 0 ? -v < tolerance : v < tolerance;
char tmp[256]; };
modp_ulitoa10(u, tmp);
Add(tmp); if ( approx_equal(d, nearbyint(d), 1e-9) && std::isfinite(d) && ! strchr(tmp, 'e') )
} // disambiguate from integer
} Add(".0");
}
void ODesc::Add(double d, bool no_exp) }
{
if ( IsBinary() ) void ODesc::Add(const IPAddr& addr) { Add(addr.AsString()); }
AddBytes(&d, sizeof(d));
else void ODesc::Add(const IPPrefix& prefix) { Add(prefix.AsString()); }
{
// Buffer needs enough chars to store max. possible "double" value void ODesc::AddCS(const char* s) {
// of 1.79e308 without using scientific notation. int n = strlen(s);
char tmp[350]; Add(n);
if ( ! IsBinary() )
if ( no_exp ) Add(" ");
modp_dtoa3(d, tmp, sizeof(tmp), IsReadable() ? 6 : 8); Add(s);
else }
modp_dtoa2(d, tmp, IsReadable() ? 6 : 8);
void ODesc::AddBytes(const String* s) {
Add(tmp); if ( IsReadable() ) {
if ( Style() == RAW_STYLE )
auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool AddBytes(reinterpret_cast<const char*>(s->Bytes()), s->Len());
{ else {
auto v = a - b; const char* str = s->Render(String::EXPANDED_STRING);
return v < 0 ? -v < tolerance : v < tolerance; Add(str);
}; delete[] str;
}
if ( approx_equal(d, nearbyint(d), 1e-9) && std::isfinite(d) && ! strchr(tmp, 'e') ) }
// disambiguate from integer else {
Add(".0"); Add(s->Len());
} if ( ! IsBinary() )
} Add(" ");
AddBytes(s->Bytes(), s->Len());
void ODesc::Add(const IPAddr& addr) }
{ }
Add(addr.AsString());
} void ODesc::Indent() {
if ( indent_with_spaces > 0 ) {
void ODesc::Add(const IPPrefix& prefix) for ( int i = 0; i < indent_level; ++i )
{ for ( int j = 0; j < indent_with_spaces; ++j )
Add(prefix.AsString()); Add(" ", 0);
} }
else {
void ODesc::AddCS(const char* s) for ( int i = 0; i < indent_level; ++i )
{ Add("\t", 0);
int n = strlen(s); }
Add(n); }
if ( ! IsBinary() )
Add(" "); static bool starts_with(const char* str1, const char* str2, size_t len) {
Add(s); for ( size_t i = 0; i < len; ++i )
} if ( str1[i] != str2[i] )
return false;
void ODesc::AddBytes(const String* s)
{ return true;
if ( IsReadable() ) }
{
if ( Style() == RAW_STYLE ) size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) {
AddBytes(reinterpret_cast<const char*>(s->Bytes()), s->Len()); if ( escape_sequences.empty() )
else return 0;
{
const char* str = s->Render(String::EXPANDED_STRING); for ( const auto& esc_str : escape_sequences ) {
Add(str); size_t esc_len = esc_str.length();
delete[] str;
} if ( start + esc_len > end )
} continue;
else
{ if ( starts_with(start, esc_str.c_str(), esc_len) )
Add(s->Len()); return esc_len;
if ( ! IsBinary() ) }
Add(" ");
AddBytes(s->Bytes(), s->Len()); return 0;
} }
}
std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n) {
void ODesc::Indent() if ( IsBinary() )
{ return {nullptr, 0};
if ( indent_with_spaces > 0 )
{ for ( size_t i = 0; i < n; ++i ) {
for ( int i = 0; i < indent_level; ++i ) auto printable = isprint(bytes[i]);
for ( int j = 0; j < indent_with_spaces; ++j )
Add(" ", 0); if ( ! printable && ! utf8 )
} return {bytes + i, 1};
else
{ if ( bytes[i] == '\\' )
for ( int i = 0; i < indent_level; ++i ) return {bytes + i, 1};
Add("\t", 0);
} size_t len = StartsWithEscapeSequence(bytes + i, bytes + n);
}
if ( len )
static bool starts_with(const char* str1, const char* str2, size_t len) return {bytes + i, len};
{ }
for ( size_t i = 0; i < len; ++i )
if ( str1[i] != str2[i] ) return {nullptr, 0};
return false; }
return true; void ODesc::AddBytes(const void* bytes, unsigned int n) {
} if ( ! escape ) {
AddBytesRaw(bytes, n);
size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) return;
{ }
if ( escape_sequences.empty() )
return 0; const char* s = (const char*)bytes;
const char* e = (const char*)bytes + n;
for ( const auto& esc_str : escape_sequences )
{ while ( s < e ) {
size_t esc_len = esc_str.length(); auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s);
if ( start + esc_len > end ) if ( esc_start != nullptr ) {
continue; if ( utf8 ) {
std::string result = util::json_escape_utf8(s, esc_start - s, false);
if ( starts_with(start, esc_str.c_str(), esc_len) ) AddBytesRaw(result.c_str(), result.size());
return esc_len; }
} else
AddBytesRaw(s, esc_start - s);
return 0;
} util::get_escaped_string(this, esc_start, esc_len, true);
s = esc_start + esc_len;
std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n) }
{ else {
if ( IsBinary() ) if ( utf8 ) {
return {nullptr, 0}; std::string result = util::json_escape_utf8(s, e - s, false);
AddBytesRaw(result.c_str(), result.size());
for ( size_t i = 0; i < n; ++i ) }
{ else
auto printable = isprint(bytes[i]); AddBytesRaw(s, e - s);
if ( ! printable && ! utf8 ) break;
return {bytes + i, 1}; }
}
if ( bytes[i] == '\\' ) }
return {bytes + i, 1};
void ODesc::AddBytesRaw(const void* bytes, unsigned int n) {
size_t len = StartsWithEscapeSequence(bytes + i, bytes + n); if ( n == 0 )
return;
if ( len )
return {bytes + i, len}; if ( f ) {
} static bool write_failed = false;
return {nullptr, 0}; if ( ! f->Write((const char*)bytes, n) ) {
} if ( ! write_failed )
// Most likely it's a "disk full" so report
void ODesc::AddBytes(const void* bytes, unsigned int n) // subsequent failures only once.
{ reporter->Error("error writing to %s: %s", f->Name(), strerror(errno));
if ( ! escape )
{ write_failed = true;
AddBytesRaw(bytes, n); return;
return; }
}
write_failed = false;
const char* s = (const char*)bytes; }
const char* e = (const char*)bytes + n;
else {
while ( s < e ) Grow(n);
{
auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s); // The following casting contortions are necessary because
// simply using &base[offset] generates complaints about
if ( esc_start != nullptr ) // using a void* for pointer arithmetic.
{ memcpy((void*)&((char*)base)[offset], bytes, n);
if ( utf8 ) offset += n;
{
std::string result = util::json_escape_utf8(s, esc_start - s, false); ((char*)base)[offset] = '\0'; // ensure that always NUL-term.
AddBytesRaw(result.c_str(), result.size()); }
} }
else
AddBytesRaw(s, esc_start - s); void ODesc::Grow(unsigned int n) {
bool size_changed = false;
util::get_escaped_string(this, esc_start, esc_len, true); while ( offset + n + SLOP >= size ) {
s = esc_start + esc_len; size *= 2;
} size_changed = true;
else }
{
if ( utf8 ) if ( size_changed )
{ base = util::safe_realloc(base, size);
std::string result = util::json_escape_utf8(s, e - s, false); }
AddBytesRaw(result.c_str(), result.size());
} void ODesc::Clear() {
else offset = 0;
AddBytesRaw(s, e - s);
// If we've allocated an exceedingly large amount of space, free it.
break; if ( size > 10 * 1024 * 1024 ) {
} free(base);
} size = DEFAULT_SIZE;
} base = util::safe_malloc(size);
((char*)base)[0] = '\0';
void ODesc::AddBytesRaw(const void* bytes, unsigned int n) }
{ }
if ( n == 0 )
return; bool ODesc::PushType(const Type* type) {
auto res = encountered_types.insert(type);
if ( f ) return std::get<1>(res);
{ }
static bool write_failed = false;
bool ODesc::PopType(const Type* type) {
if ( ! f->Write((const char*)bytes, n) ) size_t res = encountered_types.erase(type);
{ return (res == 1);
if ( ! write_failed ) }
// Most likely it's a "disk full" so report
// subsequent failures only once. bool ODesc::FindType(const Type* type) {
reporter->Error("error writing to %s: %s", f->Name(), strerror(errno)); auto res = encountered_types.find(type);
write_failed = true; if ( res != encountered_types.end() )
return; return true;
}
return false;
write_failed = false; }
}
std::string obj_desc(const Obj* o) {
else static ODesc d;
{
Grow(n); d.Clear();
o->Describe(&d);
// The following casting contortions are necessary because d.SP();
// simply using &base[offset] generates complaints about o->GetLocationInfo()->Describe(&d);
// using a void* for pointer arithmetic.
memcpy((void*)&((char*)base)[offset], bytes, n); return d.Description();
offset += n; }
((char*)base)[offset] = '\0'; // ensure that always NUL-term. std::string obj_desc_short(const Obj* o) {
} static ODesc d;
}
d.SetShort(true);
void ODesc::Grow(unsigned int n) d.Clear();
{ o->Describe(&d);
bool size_changed = false;
while ( offset + n + SLOP >= size ) return d.Description();
{ }
size *= 2;
size_changed = true; } // namespace zeek
}
if ( size_changed )
base = util::safe_realloc(base, size);
}
void ODesc::Clear()
{
offset = 0;
// If we've allocated an exceedingly large amount of space, free it.
if ( size > 10 * 1024 * 1024 )
{
free(base);
size = DEFAULT_SIZE;
base = util::safe_malloc(size);
((char*)base)[0] = '\0';
}
}
bool ODesc::PushType(const Type* type)
{
auto res = encountered_types.insert(type);
return std::get<1>(res);
}
bool ODesc::PopType(const Type* type)
{
size_t res = encountered_types.erase(type);
return (res == 1);
}
bool ODesc::FindType(const Type* type)
{
auto res = encountered_types.find(type);
if ( res != encountered_types.end() )
return true;
return false;
}
std::string obj_desc(const Obj* o)
{
static ODesc d;
d.Clear();
o->Describe(&d);
d.SP();
o->GetLocationInfo()->Describe(&d);
return d.Description();
}
std::string obj_desc_short(const Obj* o)
{
static ODesc d;
d.SetShort(true);
d.Clear();
o->Describe(&d);
return d.Description();
}
} // namespace zeek

View file

@ -8,222 +8,207 @@
#include <utility> #include <utility>
#include "zeek/ZeekString.h" // for byte_vec #include "zeek/ZeekString.h" // for byte_vec
#include "zeek/util.h" // for zeek_int_t #include "zeek/util.h" // for zeek_int_t
namespace zeek namespace zeek {
{
class IPAddr; class IPAddr;
class IPPrefix; class IPPrefix;
class File; class File;
class Type; class Type;
enum DescType enum DescType {
{ DESC_READABLE,
DESC_READABLE, DESC_BINARY,
DESC_BINARY, };
};
enum DescStyle enum DescStyle {
{ STANDARD_STYLE,
STANDARD_STYLE, RAW_STYLE,
RAW_STYLE, };
};
class ODesc class ODesc {
{
public: public:
explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr); explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr);
~ODesc(); ~ODesc();
bool IsReadable() const { return type == DESC_READABLE; } bool IsReadable() const { return type == DESC_READABLE; }
bool IsBinary() const { return type == DESC_BINARY; } bool IsBinary() const { return type == DESC_BINARY; }
bool IsShort() const { return is_short; } bool IsShort() const { return is_short; }
void SetShort() { is_short = true; } void SetShort() { is_short = true; }
void SetShort(bool s) { is_short = s; } void SetShort(bool s) { is_short = s; }
// Whether we want to have quotes around strings. // Whether we want to have quotes around strings.
bool WantQuotes() const { return want_quotes; } bool WantQuotes() const { return want_quotes; }
void SetQuotes(bool q) { want_quotes = q; } void SetQuotes(bool q) { want_quotes = q; }
// Whether to ensure deterministic output (for example, when // Whether to ensure deterministic output (for example, when
// describing TableVal's). // describing TableVal's).
bool WantDeterminism() const { return want_determinism; } bool WantDeterminism() const { return want_determinism; }
void SetDeterminism(bool d) { want_determinism = d; } void SetDeterminism(bool d) { want_determinism = d; }
// Whether we want to print statistics like access time and execution // Whether we want to print statistics like access time and execution
// count where available. // count where available.
bool IncludeStats() const { return include_stats; } bool IncludeStats() const { return include_stats; }
void SetIncludeStats(bool s) { include_stats = s; } void SetIncludeStats(bool s) { include_stats = s; }
DescStyle Style() const { return style; } DescStyle Style() const { return style; }
void SetStyle(DescStyle s) { style = s; } void SetStyle(DescStyle s) { style = s; }
void SetFlush(bool arg_do_flush) { do_flush = arg_do_flush; } void SetFlush(bool arg_do_flush) { do_flush = arg_do_flush; }
void EnableEscaping(); void EnableEscaping();
void EnableUTF8(); void EnableUTF8();
void AddEscapeSequence(const char* s) { escape_sequences.insert(s); } void AddEscapeSequence(const char* s) { escape_sequences.insert(s); }
void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); } void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); }
void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); } void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); }
void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); } void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); }
void RemoveEscapeSequence(const char* s, size_t n) void RemoveEscapeSequence(const char* s, size_t n) { escape_sequences.erase(std::string(s, n)); }
{ void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); }
escape_sequences.erase(std::string(s, n));
}
void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); }
void PushIndent(); void PushIndent();
void PopIndent(); void PopIndent();
void PopIndentNoNL(); void PopIndentNoNL();
int GetIndentLevel() const { return indent_level; } int GetIndentLevel() const { return indent_level; }
void ClearIndentLevel() { indent_level = 0; } void ClearIndentLevel() { indent_level = 0; }
int IndentSpaces() const { return indent_with_spaces; } int IndentSpaces() const { return indent_with_spaces; }
void SetIndentSpaces(int i) { indent_with_spaces = i; } void SetIndentSpaces(int i) { indent_with_spaces = i; }
void Add(const char* s, int do_indent = 1); void Add(const char* s, int do_indent = 1);
void AddN(const char* s, int len) { AddBytes(s, len); } void AddN(const char* s, int len) { AddBytes(s, len); }
void Add(const std::string& s) { AddBytes(s.data(), s.size()); } void Add(const std::string& s) { AddBytes(s.data(), s.size()); }
void Add(int i); void Add(int i);
void Add(uint32_t u); void Add(uint32_t u);
void Add(int64_t i); void Add(int64_t i);
void Add(uint64_t u); void Add(uint64_t u);
void Add(double d, bool no_exp = false); void Add(double d, bool no_exp = false);
void Add(const IPAddr& addr); void Add(const IPAddr& addr);
void Add(const IPPrefix& prefix); void Add(const IPPrefix& prefix);
// Add s as a counted string. // Add s as a counted string.
void AddCS(const char* s); void AddCS(const char* s);
void AddBytes(const String* s); void AddBytes(const String* s);
void Add(const char* s1, const char* s2) void Add(const char* s1, const char* s2) {
{ Add(s1);
Add(s1); Add(s2);
Add(s2); }
}
void AddSP(const char* s1, const char* s2) void AddSP(const char* s1, const char* s2) {
{ Add(s1);
Add(s1); AddSP(s2);
AddSP(s2); }
}
void AddSP(const char* s) void AddSP(const char* s) {
{ Add(s);
Add(s); SP();
SP(); }
}
void AddCount(zeek_int_t n) void AddCount(zeek_int_t n) {
{ if ( ! IsReadable() ) {
if ( ! IsReadable() ) Add(n);
{ SP();
Add(n); }
SP(); }
}
}
void SP() void SP() {
{ if ( ! IsBinary() )
if ( ! IsBinary() ) Add(" ", 0);
Add(" ", 0); }
} void NL() {
void NL() if ( ! IsBinary() && ! is_short )
{ Add("\n", 0);
if ( ! IsBinary() && ! is_short ) }
Add("\n", 0);
}
// Bypasses the escaping enabled via EnableEscaping(). // Bypasses the escaping enabled via EnableEscaping().
void AddRaw(const char* s, int len) { AddBytesRaw(s, len); } void AddRaw(const char* s, int len) { AddBytesRaw(s, len); }
void AddRaw(const std::string& s) { AddBytesRaw(s.data(), s.size()); } void AddRaw(const std::string& s) { AddBytesRaw(s.data(), s.size()); }
// Returns the description as a string. // Returns the description as a string.
const char* Description() const { return (const char*)base; } const char* Description() const { return (const char*)base; }
const u_char* Bytes() const { return (const u_char*)base; } const u_char* Bytes() const { return (const u_char*)base; }
byte_vec TakeBytes() byte_vec TakeBytes() {
{ const void* t = base;
const void* t = base; base = nullptr;
base = nullptr; size = 0;
size = 0;
// Don't clear offset, as we want to still support // Don't clear offset, as we want to still support
// subsequent calls to Len(). // subsequent calls to Len().
return byte_vec(t); return byte_vec(t);
} }
int Len() const { return offset; } int Len() const { return offset; }
void Clear(); void Clear();
// Used to determine recursive types. Records push their types on here; // Used to determine recursive types. Records push their types on here;
// if the same type (by address) is re-encountered, processing aborts. // if the same type (by address) is re-encountered, processing aborts.
bool PushType(const Type* type); bool PushType(const Type* type);
bool PopType(const Type* type); bool PopType(const Type* type);
bool FindType(const Type* type); bool FindType(const Type* type);
protected: protected:
void Indent(); void Indent();
void AddBytes(const void* bytes, unsigned int n); void AddBytes(const void* bytes, unsigned int n);
void AddBytesRaw(const void* bytes, unsigned int n); void AddBytesRaw(const void* bytes, unsigned int n);
// Make buffer big enough for n bytes beyond bufp. // Make buffer big enough for n bytes beyond bufp.
void Grow(unsigned int n); void Grow(unsigned int n);
/** /**
* Returns the location of the first place in the bytes to be hex-escaped. * Returns the location of the first place in the bytes to be hex-escaped.
* *
* @param bytes the starting memory address to start searching for * @param bytes the starting memory address to start searching for
* escapable character. * escapable character.
* @param n the maximum number of bytes to search. * @param n the maximum number of bytes to search.
* @return a pair whose first element represents a starting memory address * @return a pair whose first element represents a starting memory address
* to be escaped up to the number of characters indicated by the * to be escaped up to the number of characters indicated by the
* second element. The first element may be 0 if nothing is * second element. The first element may be 0 if nothing is
* to be escaped. * to be escaped.
*/ */
std::pair<const char*, size_t> FirstEscapeLoc(const char* bytes, size_t n); std::pair<const char*, size_t> FirstEscapeLoc(const char* bytes, size_t n);
/** /**
* @param start start of string to check for starting with an escape * @param start start of string to check for starting with an escape
* sequence. * sequence.
* @param end one byte past the last character in the string. * @param end one byte past the last character in the string.
* @return The number of bytes in the escape sequence that the string * @return The number of bytes in the escape sequence that the string
* starts with. * starts with.
*/ */
size_t StartsWithEscapeSequence(const char* start, const char* end); size_t StartsWithEscapeSequence(const char* start, const char* end);
DescType type; DescType type;
DescStyle style; DescStyle style;
void* base; // beginning of buffer void* base; // beginning of buffer
unsigned int offset; // where we are in the buffer unsigned int offset; // where we are in the buffer
unsigned int size; // size of buffer in bytes unsigned int size; // size of buffer in bytes
bool utf8; // whether valid utf-8 sequences may pass through unescaped bool utf8; // whether valid utf-8 sequences may pass through unescaped
bool escape; // escape unprintable characters in output? bool escape; // escape unprintable characters in output?
bool is_short; bool is_short;
bool want_quotes; bool want_quotes;
bool want_determinism; bool want_determinism;
bool do_flush; bool do_flush;
bool include_stats; bool include_stats;
int indent_with_spaces; int indent_with_spaces;
int indent_level; int indent_level;
using escape_set = std::set<std::string>; using escape_set = std::set<std::string>;
escape_set escape_sequences; // additional sequences of chars to escape escape_set escape_sequences; // additional sequences of chars to escape
File* f; // or the file we're using. File* f; // or the file we're using.
std::set<const Type*> encountered_types; std::set<const Type*> encountered_types;
}; };
// Returns a string representation of an object's description. Used for // Returns a string representation of an object's description. Used for
// debugging and error messages. takes a bare pointer rather than an // debugging and error messages. takes a bare pointer rather than an
@ -235,4 +220,4 @@ std::string obj_desc(const Obj* o);
// Same as obj_desc(), but ensure it is short and don't include location info. // Same as obj_desc(), but ensure it is short and don't include location info.
std::string obj_desc_short(const Obj* o); std::string obj_desc_short(const Obj* o);
} // namespace zeek } // namespace zeek

View file

@ -5,458 +5,435 @@
#include "zeek/3rdparty/doctest.h" #include "zeek/3rdparty/doctest.h"
#include "zeek/Hash.h" #include "zeek/Hash.h"
namespace zeek namespace zeek {
{
// namespace detail // namespace detail
TEST_SUITE_BEGIN("Dict"); TEST_SUITE_BEGIN("Dict");
TEST_CASE("dict construction") TEST_CASE("dict construction") {
{ PDict<int> dict;
PDict<int> dict; CHECK(! dict.IsOrdered());
CHECK(! dict.IsOrdered()); CHECK(dict.Length() == 0);
CHECK(dict.Length() == 0);
PDict<int> dict2(ORDERED);
CHECK(dict2.IsOrdered());
CHECK(dict2.Length() == 0);
}
TEST_CASE("dict operation") PDict<int> dict2(ORDERED);
{ CHECK(dict2.IsOrdered());
PDict<uint32_t> dict; CHECK(dict2.Length() == 0);
}
uint32_t val = 10; TEST_CASE("dict operation") {
uint32_t key_val = 5; PDict<uint32_t> dict;
detail::HashKey* key = new detail::HashKey(key_val); uint32_t val = 10;
dict.Insert(key, &val); uint32_t key_val = 5;
CHECK(dict.Length() == 1);
detail::HashKey* key = new detail::HashKey(key_val);
detail::HashKey* key2 = new detail::HashKey(key_val); dict.Insert(key, &val);
uint32_t* lookup = dict.Lookup(key2); CHECK(dict.Length() == 1);
CHECK(*lookup == val);
detail::HashKey* key2 = new detail::HashKey(key_val);
dict.Remove(key2); uint32_t* lookup = dict.Lookup(key2);
CHECK(dict.Length() == 0); CHECK(*lookup == val);
uint32_t* lookup2 = dict.Lookup(key2);
CHECK(lookup2 == (uint32_t*)0); dict.Remove(key2);
delete key2; CHECK(dict.Length() == 0);
uint32_t* lookup2 = dict.Lookup(key2);
CHECK(dict.MaxLength() == 1); CHECK(lookup2 == (uint32_t*)0);
CHECK(dict.NumCumulativeInserts() == 1); delete key2;
dict.Insert(key, &val); CHECK(dict.MaxLength() == 1);
dict.Remove(key); CHECK(dict.NumCumulativeInserts() == 1);
CHECK(dict.MaxLength() == 1); dict.Insert(key, &val);
CHECK(dict.NumCumulativeInserts() == 2); dict.Remove(key);
uint32_t val2 = 15; CHECK(dict.MaxLength() == 1);
uint32_t key_val2 = 25; CHECK(dict.NumCumulativeInserts() == 2);
key2 = new detail::HashKey(key_val2);
uint32_t val2 = 15;
dict.Insert(key, &val); uint32_t key_val2 = 25;
dict.Insert(key2, &val2); key2 = new detail::HashKey(key_val2);
CHECK(dict.Length() == 2);
CHECK(dict.NumCumulativeInserts() == 4); dict.Insert(key, &val);
dict.Insert(key2, &val2);
dict.Clear(); CHECK(dict.Length() == 2);
CHECK(dict.Length() == 0); CHECK(dict.NumCumulativeInserts() == 4);
delete key; dict.Clear();
delete key2; CHECK(dict.Length() == 0);
}
delete key;
TEST_CASE("dict nthentry") delete key2;
{ }
PDict<uint32_t> unordered(UNORDERED);
PDict<uint32_t> ordered(ORDERED); TEST_CASE("dict nthentry") {
PDict<uint32_t> unordered(UNORDERED);
uint32_t val = 15; PDict<uint32_t> ordered(ORDERED);
uint32_t key_val = 5;
detail::HashKey* okey = new detail::HashKey(key_val); uint32_t val = 15;
detail::HashKey* ukey = new detail::HashKey(key_val); uint32_t key_val = 5;
detail::HashKey* okey = new detail::HashKey(key_val);
uint32_t val2 = 10; detail::HashKey* ukey = new detail::HashKey(key_val);
uint32_t key_val2 = 25;
detail::HashKey* okey2 = new detail::HashKey(key_val2); uint32_t val2 = 10;
detail::HashKey* ukey2 = new detail::HashKey(key_val2); uint32_t key_val2 = 25;
detail::HashKey* okey2 = new detail::HashKey(key_val2);
unordered.Insert(ukey, &val); detail::HashKey* ukey2 = new detail::HashKey(key_val2);
unordered.Insert(ukey2, &val2);
unordered.Insert(ukey, &val);
ordered.Insert(okey, &val); unordered.Insert(ukey2, &val2);
ordered.Insert(okey2, &val2);
ordered.Insert(okey, &val);
// NthEntry returns null for unordered dicts ordered.Insert(okey2, &val2);
uint32_t* lookup = unordered.NthEntry(0);
CHECK(lookup == (uint32_t*)0); // NthEntry returns null for unordered dicts
uint32_t* lookup = unordered.NthEntry(0);
// Ordered dicts are based on order of insertion, nothing about the CHECK(lookup == (uint32_t*)0);
// data itself
lookup = ordered.NthEntry(0); // Ordered dicts are based on order of insertion, nothing about the
CHECK(*lookup == 15); // data itself
lookup = ordered.NthEntry(0);
delete okey; CHECK(*lookup == 15);
delete okey2;
delete ukey; delete okey;
delete ukey2; delete okey2;
} delete ukey;
delete ukey2;
TEST_CASE("dict iteration") }
{
PDict<uint32_t> dict; TEST_CASE("dict iteration") {
PDict<uint32_t> dict;
uint32_t val = 15;
uint32_t key_val = 5; uint32_t val = 15;
detail::HashKey* key = new detail::HashKey(key_val); uint32_t key_val = 5;
detail::HashKey* key = new detail::HashKey(key_val);
uint32_t val2 = 10;
uint32_t key_val2 = 25; uint32_t val2 = 10;
detail::HashKey* key2 = new detail::HashKey(key_val2); uint32_t key_val2 = 25;
detail::HashKey* key2 = new detail::HashKey(key_val2);
dict.Insert(key, &val);
dict.Insert(key2, &val2); dict.Insert(key, &val);
dict.Insert(key2, &val2);
int count = 0;
int count = 0;
for ( const auto& entry : dict )
{ for ( const auto& entry : dict ) {
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint64_t k = *(uint32_t*)entry.GetKey(); uint64_t k = *(uint32_t*)entry.GetKey();
switch ( count ) switch ( count ) {
{ case 0:
case 0: CHECK(k == key_val2);
CHECK(k == key_val2); CHECK(*v == val2);
CHECK(*v == val2); break;
break; case 1:
case 1: CHECK(k == key_val);
CHECK(k == key_val); CHECK(*v == val);
CHECK(*v == val); break;
break; default: break;
default: }
break;
} count++;
}
count++;
} PDict<uint32_t>::iterator it;
it = dict.begin();
PDict<uint32_t>::iterator it; it = dict.end();
it = dict.begin(); PDict<uint32_t>::iterator it2 = it;
it = dict.end();
PDict<uint32_t>::iterator it2 = it; CHECK(count == 2);
CHECK(count == 2); delete key;
delete key2;
delete key; }
delete key2;
} TEST_CASE("dict robust iteration") {
PDict<uint32_t> dict;
TEST_CASE("dict robust iteration")
{ uint32_t val = 15;
PDict<uint32_t> dict; uint32_t key_val = 5;
detail::HashKey* key = new detail::HashKey(key_val);
uint32_t val = 15;
uint32_t key_val = 5; uint32_t val2 = 10;
detail::HashKey* key = new detail::HashKey(key_val); uint32_t key_val2 = 25;
detail::HashKey* key2 = new detail::HashKey(key_val2);
uint32_t val2 = 10;
uint32_t key_val2 = 25; uint32_t val3 = 20;
detail::HashKey* key2 = new detail::HashKey(key_val2); uint32_t key_val3 = 35;
detail::HashKey* key3 = new detail::HashKey(key_val3);
uint32_t val3 = 20;
uint32_t key_val3 = 35; dict.Insert(key, &val);
detail::HashKey* key3 = new detail::HashKey(key_val3); dict.Insert(key2, &val2);
dict.Insert(key, &val); {
dict.Insert(key2, &val2); int count = 0;
auto it = dict.begin_robust();
{
int count = 0; for ( ; it != dict.end_robust(); ++it ) {
auto it = dict.begin_robust(); auto* v = it->value;
uint64_t k = *(uint32_t*)it->GetKey();
for ( ; it != dict.end_robust(); ++it )
{ switch ( count ) {
auto* v = it->value; case 0:
uint64_t k = *(uint32_t*)it->GetKey(); CHECK(k == key_val2);
CHECK(*v == val2);
switch ( count ) dict.Insert(key3, &val3);
{ break;
case 0: case 1:
CHECK(k == key_val2); CHECK(k == key_val);
CHECK(*v == val2); CHECK(*v == val);
dict.Insert(key3, &val3); break;
break; case 2:
case 1: CHECK(k == key_val3);
CHECK(k == key_val); CHECK(*v == val3);
CHECK(*v == val); break;
break; default:
case 2: // We shouldn't get here.
CHECK(k == key_val3); CHECK(false);
CHECK(*v == val3); break;
break; }
default: count++;
// We shouldn't get here. }
CHECK(false);
break; CHECK(count == 3);
} }
count++;
} {
int count = 0;
CHECK(count == 3); auto it = dict.begin_robust();
}
for ( ; it != dict.end_robust(); ++it ) {
{ auto* v = it->value;
int count = 0; uint64_t k = *(uint32_t*)it->GetKey();
auto it = dict.begin_robust();
switch ( count ) {
for ( ; it != dict.end_robust(); ++it ) case 0:
{ CHECK(k == key_val2);
auto* v = it->value; CHECK(*v == val2);
uint64_t k = *(uint32_t*)it->GetKey(); dict.Insert(key3, &val3);
dict.Remove(key3);
switch ( count ) break;
{ case 1:
case 0: CHECK(k == key_val);
CHECK(k == key_val2); CHECK(*v == val);
CHECK(*v == val2); break;
dict.Insert(key3, &val3); default:
dict.Remove(key3); // We shouldn't get here.
break; CHECK(false);
case 1: break;
CHECK(k == key_val); }
CHECK(*v == val); count++;
break; }
default:
// We shouldn't get here. CHECK(count == 2);
CHECK(false); }
break;
} delete key;
count++; delete key2;
} delete key3;
}
CHECK(count == 2);
} TEST_CASE("dict ordered iteration") {
PDict<uint32_t> dict(DictOrder::ORDERED);
delete key;
delete key2; // These key values are specifically contrived to be inserted
delete key3; // into the dictionary in a different order by default.
} uint32_t val = 15;
uint32_t key_val = 5;
TEST_CASE("dict ordered iteration") auto key = std::make_unique<detail::HashKey>(key_val);
{
PDict<uint32_t> dict(DictOrder::ORDERED); uint32_t val2 = 10;
uint32_t key_val2 = 25;
// These key values are specifically contrived to be inserted auto key2 = std::make_unique<detail::HashKey>(key_val2);
// into the dictionary in a different order by default.
uint32_t val = 15; uint32_t val3 = 30;
uint32_t key_val = 5; uint32_t key_val3 = 45;
auto key = std::make_unique<detail::HashKey>(key_val); auto key3 = std::make_unique<detail::HashKey>(key_val3);
uint32_t val2 = 10; uint32_t val4 = 20;
uint32_t key_val2 = 25; uint32_t key_val4 = 35;
auto key2 = std::make_unique<detail::HashKey>(key_val2); auto key4 = std::make_unique<detail::HashKey>(key_val4);
uint32_t val3 = 30; // Only insert the first three to start with so we can test the order
uint32_t key_val3 = 45; // being the same after a later insertion.
auto key3 = std::make_unique<detail::HashKey>(key_val3); dict.Insert(key.get(), &val);
dict.Insert(key2.get(), &val2);
uint32_t val4 = 20; dict.Insert(key3.get(), &val3);
uint32_t key_val4 = 35;
auto key4 = std::make_unique<detail::HashKey>(key_val4); int count = 0;
// Only insert the first three to start with so we can test the order for ( const auto& entry : dict ) {
// being the same after a later insertion. auto* v = static_cast<uint32_t*>(entry.value);
dict.Insert(key.get(), &val); uint32_t k = *(uint32_t*)entry.GetKey();
dict.Insert(key2.get(), &val2);
dict.Insert(key3.get(), &val3); // The keys should be returned in the same order we inserted
// them, which is 5, 25, 45.
int count = 0; if ( count == 0 )
CHECK(k == 5);
for ( const auto& entry : dict ) else if ( count == 1 )
{ CHECK(k == 25);
auto* v = static_cast<uint32_t*>(entry.value); else if ( count == 2 )
uint32_t k = *(uint32_t*)entry.GetKey(); CHECK(k == 45);
// The keys should be returned in the same order we inserted count++;
// them, which is 5, 25, 45. }
if ( count == 0 )
CHECK(k == 5); dict.Insert(key4.get(), &val4);
else if ( count == 1 ) count = 0;
CHECK(k == 25);
else if ( count == 2 ) for ( const auto& entry : dict ) {
CHECK(k == 45); auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey();
count++;
} // The keys should be returned in the same order we inserted
// them, which is 5, 25, 45, 35.
dict.Insert(key4.get(), &val4); if ( count == 0 )
count = 0; CHECK(k == 5);
else if ( count == 1 )
for ( const auto& entry : dict ) CHECK(k == 25);
{ else if ( count == 2 )
auto* v = static_cast<uint32_t*>(entry.value); CHECK(k == 45);
uint32_t k = *(uint32_t*)entry.GetKey(); else if ( count == 3 )
CHECK(k == 35);
// The keys should be returned in the same order we inserted
// them, which is 5, 25, 45, 35. count++;
if ( count == 0 ) }
CHECK(k == 5);
else if ( count == 1 ) dict.Remove(key2.get());
CHECK(k == 25); count = 0;
else if ( count == 2 )
CHECK(k == 45); for ( const auto& entry : dict ) {
else if ( count == 3 ) auto* v = static_cast<uint32_t*>(entry.value);
CHECK(k == 35); uint32_t k = *(uint32_t*)entry.GetKey();
count++; // The keys should be returned in the same order we inserted
} // them, which is 5, 45, 35.
if ( count == 0 )
dict.Remove(key2.get()); CHECK(k == 5);
count = 0; else if ( count == 1 )
CHECK(k == 45);
for ( const auto& entry : dict ) else if ( count == 2 )
{ CHECK(k == 35);
auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey(); count++;
}
// The keys should be returned in the same order we inserted }
// them, which is 5, 45, 35.
if ( count == 0 ) class DictTestDummy {
CHECK(k == 5);
else if ( count == 1 )
CHECK(k == 45);
else if ( count == 2 )
CHECK(k == 35);
count++;
}
}
class DictTestDummy
{
public: public:
DictTestDummy(int v) : v(v) { } DictTestDummy(int v) : v(v) {}
~DictTestDummy() = default; ~DictTestDummy() = default;
int v = 0; int v = 0;
}; };
TEST_CASE("dict robust iteration replacement") TEST_CASE("dict robust iteration replacement") {
{ PDict<DictTestDummy> dict;
PDict<DictTestDummy> dict;
DictTestDummy* val1 = new DictTestDummy(15); DictTestDummy* val1 = new DictTestDummy(15);
uint32_t key_val1 = 5; uint32_t key_val1 = 5;
detail::HashKey* key1 = new detail::HashKey(key_val1); detail::HashKey* key1 = new detail::HashKey(key_val1);
DictTestDummy* val2 = new DictTestDummy(10); DictTestDummy* val2 = new DictTestDummy(10);
uint32_t key_val2 = 25; uint32_t key_val2 = 25;
detail::HashKey* key2 = new detail::HashKey(key_val2); detail::HashKey* key2 = new detail::HashKey(key_val2);
DictTestDummy* val3 = new DictTestDummy(20); DictTestDummy* val3 = new DictTestDummy(20);
uint32_t key_val3 = 35; uint32_t key_val3 = 35;
detail::HashKey* key3 = new detail::HashKey(key_val3); detail::HashKey* key3 = new detail::HashKey(key_val3);
dict.Insert(key1, val1); dict.Insert(key1, val1);
dict.Insert(key2, val2); dict.Insert(key2, val2);
dict.Insert(key3, val3); dict.Insert(key3, val3);
int count = 0; int count = 0;
auto it = dict.begin_robust(); auto it = dict.begin_robust();
// Iterate past the first couple of elements so we're not done, but the // Iterate past the first couple of elements so we're not done, but the
// iterator is still pointing at a valid element. // iterator is still pointing at a valid element.
for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) { } for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) {
}
// Store off the value at this iterator index // Store off the value at this iterator index
auto* v = it->value; auto* v = it->value;
// Replace it with something else // Replace it with something else
auto k = it->GetHashKey(); auto k = it->GetHashKey();
DictTestDummy* val4 = new DictTestDummy(50); DictTestDummy* val4 = new DictTestDummy(50);
dict.Insert(k.get(), val4); dict.Insert(k.get(), val4);
// Delete the original element // Delete the original element
delete val2; delete val2;
// This shouldn't crash with AddressSanitizer // This shouldn't crash with AddressSanitizer
for ( ; it != dict.end_robust(); ++it ) for ( ; it != dict.end_robust(); ++it ) {
{ uint64_t k = *(uint32_t*)it->GetKey();
uint64_t k = *(uint32_t*)it->GetKey(); auto* v = it->value;
auto* v = it->value; CHECK(v->v == 50);
CHECK(v->v == 50); }
}
delete key1; delete key1;
delete key2; delete key2;
delete key3; delete key3;
delete val1; delete val1;
delete val3; delete val3;
delete val4; delete val4;
} }
TEST_CASE("dict iterator invalidation") TEST_CASE("dict iterator invalidation") {
{ PDict<uint32_t> dict;
PDict<uint32_t> dict;
uint32_t val = 15; uint32_t val = 15;
uint32_t key_val = 5; uint32_t key_val = 5;
auto key = new detail::HashKey(key_val); auto key = new detail::HashKey(key_val);
uint32_t val2 = 10; uint32_t val2 = 10;
uint32_t key_val2 = 25; uint32_t key_val2 = 25;
auto key2 = new detail::HashKey(key_val2); auto key2 = new detail::HashKey(key_val2);
uint32_t val3 = 42; uint32_t val3 = 42;
uint32_t key_val3 = 37; uint32_t key_val3 = 37;
auto key3 = new detail::HashKey(key_val3); auto key3 = new detail::HashKey(key_val3);
dict.Insert(key, &val); dict.Insert(key, &val);
dict.Insert(key2, &val2); dict.Insert(key2, &val2);
detail::HashKey* it_key; detail::HashKey* it_key;
bool iterators_invalidated = false; bool iterators_invalidated = false;
auto it = dict.begin(); auto it = dict.begin();
iterators_invalidated = false; iterators_invalidated = false;
dict.Remove(key3, &iterators_invalidated); dict.Remove(key3, &iterators_invalidated);
// Key doesn't exist, nothing to remove, iteration not invalidated. // Key doesn't exist, nothing to remove, iteration not invalidated.
CHECK(! iterators_invalidated); CHECK(! iterators_invalidated);
iterators_invalidated = false; iterators_invalidated = false;
dict.Insert(key, &val2, &iterators_invalidated); dict.Insert(key, &val2, &iterators_invalidated);
// Key exists, value gets overwritten, iteration not invalidated. // Key exists, value gets overwritten, iteration not invalidated.
CHECK(! iterators_invalidated); CHECK(! iterators_invalidated);
iterators_invalidated = false; iterators_invalidated = false;
dict.Remove(key2, &iterators_invalidated); dict.Remove(key2, &iterators_invalidated);
// Key exists, gets removed, iteration is invalidated. // Key exists, gets removed, iteration is invalidated.
CHECK(iterators_invalidated); CHECK(iterators_invalidated);
it = dict.begin(); it = dict.begin();
iterators_invalidated = false; iterators_invalidated = false;
dict.Insert(key3, &val3, &iterators_invalidated); dict.Insert(key3, &val3, &iterators_invalidated);
// Key doesn't exist, gets inserted, iteration is invalidated. // Key doesn't exist, gets inserted, iteration is invalidated.
CHECK(iterators_invalidated); CHECK(iterators_invalidated);
CHECK(dict.Length() == 2); CHECK(dict.Length() == 2);
CHECK(*static_cast<uint32_t*>(dict.Lookup(key)) == val2); CHECK(*static_cast<uint32_t*>(dict.Lookup(key)) == val2);
CHECK(*static_cast<uint32_t*>(dict.Lookup(key3)) == val3); CHECK(*static_cast<uint32_t*>(dict.Lookup(key3)) == val3);
CHECK(static_cast<uint32_t*>(dict.Lookup(key2)) == nullptr); CHECK(static_cast<uint32_t*>(dict.Lookup(key2)) == nullptr);
delete key; delete key;
delete key2; delete key2;
delete key3; delete key3;
} }
// private // private
void generic_delete_func(void* v) void generic_delete_func(void* v) { free(v); }
{
free(v);
}
} // namespace zeek } // namespace zeek

2595
src/Dict.h

File diff suppressed because it is too large Load diff

View file

@ -14,153 +14,130 @@
#include "zeek/Var.h" #include "zeek/Var.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
Discarder::Discarder() Discarder::Discarder() {
{ check_ip = id::find_func("discarder_check_ip");
check_ip = id::find_func("discarder_check_ip"); check_tcp = id::find_func("discarder_check_tcp");
check_tcp = id::find_func("discarder_check_tcp"); check_udp = id::find_func("discarder_check_udp");
check_udp = id::find_func("discarder_check_udp"); check_icmp = id::find_func("discarder_check_icmp");
check_icmp = id::find_func("discarder_check_icmp");
discarder_maxlen = static_cast<int>(id::find_val("discarder_maxlen")->AsCount()); discarder_maxlen = static_cast<int>(id::find_val("discarder_maxlen")->AsCount());
} }
bool Discarder::IsActive() bool Discarder::IsActive() { return check_ip || check_tcp || check_udp || check_icmp; }
{
return check_ip || check_tcp || check_udp || check_icmp;
}
bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) {
{ bool discard_packet = false;
bool discard_packet = false;
if ( check_ip ) if ( check_ip ) {
{ zeek::Args args{ip->ToPktHdrVal()};
zeek::Args args{ip->ToPktHdrVal()};
try try {
{ discard_packet = check_ip->Invoke(&args)->AsBool();
discard_packet = check_ip->Invoke(&args)->AsBool(); }
}
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{ discard_packet = false;
discard_packet = false; }
}
if ( discard_packet ) if ( discard_packet )
return discard_packet; return discard_packet;
} }
int proto = ip->NextProto(); int proto = ip->NextProto();
if ( proto != IPPROTO_TCP && proto != IPPROTO_UDP && proto != IPPROTO_ICMP ) if ( proto != IPPROTO_TCP && proto != IPPROTO_UDP && proto != IPPROTO_ICMP )
// This is not a protocol we understand. // This is not a protocol we understand.
return false; return false;
// XXX shall we only check the first packet??? // XXX shall we only check the first packet???
if ( ip->IsFragment() ) if ( ip->IsFragment() )
// Never check any fragment. // Never check any fragment.
return false; return false;
int ip_hdr_len = ip->HdrLen(); int ip_hdr_len = ip->HdrLen();
len -= ip_hdr_len; // remove IP header len -= ip_hdr_len; // remove IP header
caplen -= ip_hdr_len; caplen -= ip_hdr_len;
bool is_tcp = (proto == IPPROTO_TCP); bool is_tcp = (proto == IPPROTO_TCP);
bool is_udp = (proto == IPPROTO_UDP); bool is_udp = (proto == IPPROTO_UDP);
int min_hdr_len = is_tcp ? sizeof(struct tcphdr) int min_hdr_len = is_tcp ? sizeof(struct tcphdr) : (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp));
: (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp));
if ( len < min_hdr_len || caplen < min_hdr_len ) if ( len < min_hdr_len || caplen < min_hdr_len )
// we don't have a complete protocol header // we don't have a complete protocol header
return false; return false;
// Where the data starts - if this is a protocol we know about, // Where the data starts - if this is a protocol we know about,
// this gets advanced past the transport header. // this gets advanced past the transport header.
const u_char* data = ip->Payload(); const u_char* data = ip->Payload();
if ( is_tcp ) if ( is_tcp ) {
{ if ( check_tcp ) {
if ( check_tcp ) const struct tcphdr* tp = (const struct tcphdr*)data;
{ int th_len = tp->th_off * 4;
const struct tcphdr* tp = (const struct tcphdr*)data;
int th_len = tp->th_off * 4;
zeek::Args args{ zeek::Args args{
ip->ToPktHdrVal(), ip->ToPktHdrVal(),
{AdoptRef{}, BuildData(data, th_len, len, caplen)}, {AdoptRef{}, BuildData(data, th_len, len, caplen)},
}; };
try try {
{ discard_packet = check_tcp->Invoke(&args)->AsBool();
discard_packet = check_tcp->Invoke(&args)->AsBool(); }
}
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{ discard_packet = false;
discard_packet = false; }
} }
} }
}
else if ( is_udp ) else if ( is_udp ) {
{ if ( check_udp ) {
if ( check_udp ) const struct udphdr* up = (const struct udphdr*)data;
{ int uh_len = sizeof(struct udphdr);
const struct udphdr* up = (const struct udphdr*)data;
int uh_len = sizeof(struct udphdr);
zeek::Args args{ zeek::Args args{
ip->ToPktHdrVal(), ip->ToPktHdrVal(),
{AdoptRef{}, BuildData(data, uh_len, len, caplen)}, {AdoptRef{}, BuildData(data, uh_len, len, caplen)},
}; };
try try {
{ discard_packet = check_udp->Invoke(&args)->AsBool();
discard_packet = check_udp->Invoke(&args)->AsBool(); }
}
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{ discard_packet = false;
discard_packet = false; }
} }
} }
}
else else {
{ if ( check_icmp ) {
if ( check_icmp ) const struct icmp* ih = (const struct icmp*)data;
{
const struct icmp* ih = (const struct icmp*)data;
zeek::Args args{ip->ToPktHdrVal()}; zeek::Args args{ip->ToPktHdrVal()};
try try {
{ discard_packet = check_icmp->Invoke(&args)->AsBool();
discard_packet = check_icmp->Invoke(&args)->AsBool(); }
}
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{ discard_packet = false;
discard_packet = false; }
} }
} }
}
return discard_packet; return discard_packet;
} }
Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) {
{ len -= hdrlen;
len -= hdrlen; caplen -= hdrlen;
caplen -= hdrlen; data += hdrlen;
data += hdrlen;
len = std::max(std::min(std::min(len, caplen), discarder_maxlen), 0); len = std::max(std::min(std::min(len, caplen), discarder_maxlen), 0);
return new StringVal(new String(data, len, true)); return new StringVal(new String(data, len, true));
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,38 +7,35 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
class Val; class Val;
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
namespace detail namespace detail {
{
class Discarder final class Discarder final {
{
public: public:
Discarder(); Discarder();
~Discarder() = default; ~Discarder() = default;
bool IsActive(); bool IsActive();
bool NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen); bool NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen);
protected: protected:
Val* BuildData(const u_char* data, int hdrlen, int len, int caplen); Val* BuildData(const u_char* data, int hdrlen, int len, int caplen);
FuncPtr check_ip; FuncPtr check_ip;
FuncPtr check_tcp; FuncPtr check_tcp;
FuncPtr check_udp; FuncPtr check_udp;
FuncPtr check_icmp; FuncPtr check_icmp;
// Maximum amount of application data passed to filtering functions. // Maximum amount of application data passed to filtering functions.
int discarder_maxlen; int discarder_maxlen;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -7,190 +7,168 @@
#include "zeek/CCL.h" #include "zeek/CCL.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
EquivClass::EquivClass(int arg_size) EquivClass::EquivClass(int arg_size) {
{ size = arg_size;
size = arg_size; fwd = new int[size];
fwd = new int[size]; bck = new int[size];
bck = new int[size]; equiv_class = new int[size];
equiv_class = new int[size]; rep = new int[size];
rep = new int[size]; ccl_flags = nullptr;
ccl_flags = nullptr; num_ecs = 0;
num_ecs = 0;
ec_nil = no_class = no_rep = size + 1; ec_nil = no_class = no_rep = size + 1;
bck[0] = ec_nil; bck[0] = ec_nil;
fwd[size - 1] = ec_nil; fwd[size - 1] = ec_nil;
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{ if ( i > 0 ) {
if ( i > 0 ) fwd[i - 1] = i;
{ bck[i] = i - 1;
fwd[i - 1] = i; }
bck[i] = i - 1;
}
equiv_class[i] = no_class; equiv_class[i] = no_class;
rep[i] = no_rep; rep[i] = no_rep;
} }
} }
EquivClass::~EquivClass() EquivClass::~EquivClass() {
{ delete[] fwd;
delete[] fwd; delete[] bck;
delete[] bck; delete[] equiv_class;
delete[] equiv_class; delete[] rep;
delete[] rep; delete[] ccl_flags;
delete[] ccl_flags; }
}
void EquivClass::ConvertCCL(CCL* ccl) void EquivClass::ConvertCCL(CCL* ccl) {
{ // For each character in the class, add the character's
// For each character in the class, add the character's // equivalence class to the new "character" class we are
// equivalence class to the new "character" class we are // creating. Thus when we are all done, the character class
// creating. Thus when we are all done, the character class // will really consist of collections of equivalence classes
// will really consist of collections of equivalence classes // instead of collections of characters.
// instead of collections of characters.
int_list* c_syms = ccl->Syms(); int_list* c_syms = ccl->Syms();
int_list* new_syms = new int_list; int_list* new_syms = new int_list;
for ( auto sym : *c_syms ) for ( auto sym : *c_syms ) {
{ if ( IsRep(sym) )
if ( IsRep(sym) ) new_syms->push_back(SymEquivClass(sym));
new_syms->push_back(SymEquivClass(sym)); }
}
ccl->ReplaceSyms(new_syms); ccl->ReplaceSyms(new_syms);
} }
int EquivClass::BuildECs() int EquivClass::BuildECs() {
{ // Create equivalence class numbers. If bck[x] is nil,
// Create equivalence class numbers. If bck[x] is nil, // then x is the representative of its equivalence class.
// then x is the representative of its equivalence class.
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
if ( bck[i] == ec_nil ) if ( bck[i] == ec_nil ) {
{ equiv_class[i] = num_ecs++;
equiv_class[i] = num_ecs++; rep[i] = i;
rep[i] = i; for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) {
for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) equiv_class[j] = equiv_class[i];
{ rep[j] = i;
equiv_class[j] = equiv_class[i]; }
rep[j] = i; }
}
}
return num_ecs; return num_ecs;
} }
void EquivClass::CCL_Use(CCL* ccl) void EquivClass::CCL_Use(CCL* ccl) {
{ // Note that it doesn't matter whether or not the character class is
// Note that it doesn't matter whether or not the character class is // negated. The same results will be obtained in either case.
// negated. The same results will be obtained in either case.
if ( ! ccl_flags ) if ( ! ccl_flags ) {
{ ccl_flags = new int[size];
ccl_flags = new int[size]; for ( int i = 0; i < size; ++i )
for ( int i = 0; i < size; ++i ) ccl_flags[i] = 0;
ccl_flags[i] = 0; }
}
int_list* csyms = ccl->Syms(); int_list* csyms = ccl->Syms();
for ( size_t i = 0; i < csyms->size(); /* no increment */ ) for ( size_t i = 0; i < csyms->size(); /* no increment */ ) {
{ int sym = (*csyms)[i];
int sym = (*csyms)[i];
int old_ec = bck[sym]; int old_ec = bck[sym];
int new_ec = sym; int new_ec = sym;
size_t j = i + 1; size_t j = i + 1;
for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) { // look for the symbol in the character class
{ // look for the symbol in the character class for ( ; j < csyms->size(); ++j ) {
for ( ; j < csyms->size(); ++j ) if ( (*csyms)[j] > k )
{ // Since the character class is sorted,
if ( (*csyms)[j] > k ) // we can stop.
// Since the character class is sorted, break;
// we can stop.
break;
if ( (*csyms)[j] == k && ! ccl_flags[j] ) if ( (*csyms)[j] == k && ! ccl_flags[j] ) {
{ // We found an old companion of sym
// We found an old companion of sym // in the ccl. Link it into the new
// in the ccl. Link it into the new // equivalence class and flag it as
// equivalence class and flag it as // having been processed.
// having been processed. bck[k] = new_ec;
bck[k] = new_ec; fwd[new_ec] = k;
fwd[new_ec] = k; new_ec = k;
new_ec = k;
// Set flag so we don't reprocess. // Set flag so we don't reprocess.
ccl_flags[j] = 1; ccl_flags[j] = 1;
// Get next equivalence class member. // Get next equivalence class member.
break; break;
} }
} }
if ( j < csyms->size() && (*csyms)[j] == k ) if ( j < csyms->size() && (*csyms)[j] == k )
// We broke out of the above loop by finding // We broke out of the above loop by finding
// an old companion - go to the next symbol. // an old companion - go to the next symbol.
continue; continue;
// Symbol isn't in character class. Put it in the old // Symbol isn't in character class. Put it in the old
// equivalence class. // equivalence class.
bck[k] = old_ec; bck[k] = old_ec;
if ( old_ec != ec_nil ) if ( old_ec != ec_nil )
fwd[old_ec] = k; fwd[old_ec] = k;
old_ec = k; old_ec = k;
} }
if ( bck[sym] != ec_nil || old_ec != bck[sym] ) if ( bck[sym] != ec_nil || old_ec != bck[sym] ) {
{ bck[sym] = ec_nil;
bck[sym] = ec_nil; fwd[old_ec] = ec_nil;
fwd[old_ec] = ec_nil; }
}
fwd[new_ec] = ec_nil; fwd[new_ec] = ec_nil;
// Find next ccl member to process. // Find next ccl member to process.
for ( ++i; i < csyms->size() && ccl_flags[i]; ++i ) for ( ++i; i < csyms->size() && ccl_flags[i]; ++i )
// Reset "doesn't need processing" flag. // Reset "doesn't need processing" flag.
ccl_flags[i] = 0; ccl_flags[i] = 0;
} }
} }
void EquivClass::UniqueChar(int sym) void EquivClass::UniqueChar(int sym) {
{ // If until now the character has been a proper subset of
// If until now the character has been a proper subset of // an equivalence class, break it away to create a new ec.
// an equivalence class, break it away to create a new ec.
if ( fwd[sym] != ec_nil ) if ( fwd[sym] != ec_nil )
bck[fwd[sym]] = bck[sym]; bck[fwd[sym]] = bck[sym];
if ( bck[sym] != ec_nil ) if ( bck[sym] != ec_nil )
fwd[bck[sym]] = fwd[sym]; fwd[bck[sym]] = fwd[sym];
fwd[sym] = ec_nil; fwd[sym] = ec_nil;
bck[sym] = ec_nil; bck[sym] = ec_nil;
} }
void EquivClass::Dump(FILE* f) void EquivClass::Dump(FILE* f) {
{ fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs);
fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs); for ( int i = 0; i < size; ++i )
for ( int i = 0; i < size; ++i ) if ( SymEquivClass(i) != 0 ) // skip usually huge default ec
if ( SymEquivClass(i) != 0 ) // skip usually huge default ec fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i));
fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i)); }
}
int EquivClass::Size() const int EquivClass::Size() const { return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4)); }
{
return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4));
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,46 +4,44 @@
#include <stdio.h> #include <stdio.h>
namespace zeek::detail namespace zeek::detail {
{
class CCL; class CCL;
class EquivClass class EquivClass {
{
public: public:
explicit EquivClass(int size); explicit EquivClass(int size);
~EquivClass(); ~EquivClass();
void UniqueChar(int sym); void UniqueChar(int sym);
void CCL_Use(CCL* ccl); void CCL_Use(CCL* ccl);
// All done adding character usage info - generate equivalence // All done adding character usage info - generate equivalence
// classes. Returns number of classes. // classes. Returns number of classes.
int BuildECs(); int BuildECs();
void ConvertCCL(CCL* ccl); void ConvertCCL(CCL* ccl);
bool IsRep(int sym) const { return rep[sym] == sym; } bool IsRep(int sym) const { return rep[sym] == sym; }
int EquivRep(int sym) const { return rep[sym]; } int EquivRep(int sym) const { return rep[sym]; }
int SymEquivClass(int sym) const { return equiv_class[sym]; } int SymEquivClass(int sym) const { return equiv_class[sym]; }
int* EquivClasses() const { return equiv_class; } int* EquivClasses() const { return equiv_class; }
int NumSyms() const { return size; } int NumSyms() const { return size; }
int NumClasses() const { return num_ecs; } int NumClasses() const { return num_ecs; }
void Dump(FILE* f); void Dump(FILE* f);
int Size() const; int Size() const;
protected: protected:
int size; // size of character set int size; // size of character set
int num_ecs; // size of equivalence classes int num_ecs; // size of equivalence classes
int* fwd; // forward list of different classes int* fwd; // forward list of different classes
int* bck; // backward list int* bck; // backward list
int* equiv_class; // symbol's equivalence class int* equiv_class; // symbol's equivalence class
int* rep; // representative for symbol's equivalence class int* rep; // representative for symbol's equivalence class
int* ccl_flags; int* ccl_flags;
int ec_nil, no_class, no_rep; int ec_nil, no_class, no_rep;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -15,204 +15,187 @@
zeek::EventMgr zeek::event_mgr; zeek::EventMgr zeek::event_mgr;
namespace zeek namespace zeek {
{
Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, util::detail::SourceID arg_src,
util::detail::SourceID arg_src, analyzer::ID arg_aid, Obj* arg_obj, double arg_ts) analyzer::ID arg_aid, Obj* arg_obj, double arg_ts)
: handler(arg_handler), args(std::move(arg_args)), src(arg_src), aid(arg_aid), ts(arg_ts), : handler(arg_handler),
obj(arg_obj), next_event(nullptr) args(std::move(arg_args)),
{ src(arg_src),
if ( obj ) aid(arg_aid),
Ref(obj); ts(arg_ts),
} obj(arg_obj),
next_event(nullptr) {
if ( obj )
Ref(obj);
}
void Event::Describe(ODesc* d) const void Event::Describe(ODesc* d) const {
{ if ( d->IsReadable() )
if ( d->IsReadable() ) d->AddSP("event");
d->AddSP("event");
bool s = d->IsShort(); bool s = d->IsShort();
d->SetShort(s); d->SetShort(s);
if ( ! d->IsBinary() ) if ( ! d->IsBinary() )
d->Add("("); d->Add("(");
describe_vals(args, d); describe_vals(args, d);
if ( ! d->IsBinary() ) if ( ! d->IsBinary() )
d->Add("("); d->Add("(");
} }
void Event::Dispatch(bool no_remote) void Event::Dispatch(bool no_remote) {
{ if ( src == util::detail::SOURCE_BROKER )
if ( src == util::detail::SOURCE_BROKER ) no_remote = true;
no_remote = true;
if ( handler->ErrorHandler() ) if ( handler->ErrorHandler() )
reporter->BeginErrorHandler(); reporter->BeginErrorHandler();
try try {
{ handler->Call(&args, no_remote, ts);
handler->Call(&args, no_remote, ts); }
}
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{ // Already reported.
// Already reported. }
}
if ( obj ) if ( obj )
// obj->EventDone(); // obj->EventDone();
Unref(obj); Unref(obj);
if ( handler->ErrorHandler() ) if ( handler->ErrorHandler() )
reporter->EndErrorHandler(); reporter->EndErrorHandler();
} }
EventMgr::EventMgr() EventMgr::EventMgr() {
{ head = tail = nullptr;
head = tail = nullptr; current_src = util::detail::SOURCE_LOCAL;
current_src = util::detail::SOURCE_LOCAL; current_aid = 0;
current_aid = 0; current_ts = 0;
current_ts = 0; src_val = nullptr;
src_val = nullptr; draining = false;
draining = false; }
}
EventMgr::~EventMgr() EventMgr::~EventMgr() {
{ while ( head ) {
while ( head ) Event* n = head->NextEvent();
{ Unref(head);
Event* n = head->NextEvent(); head = n;
Unref(head); }
head = n;
}
Unref(src_val); Unref(src_val);
} }
void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, analyzer::ID aid, Obj* obj,
analyzer::ID aid, Obj* obj, double ts) double ts) {
{ QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts));
QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts)); }
}
void EventMgr::QueueEvent(Event* event) void EventMgr::QueueEvent(Event* event) {
{ bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false);
bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false);
if ( done ) if ( done )
return; return;
if ( ! head ) if ( ! head ) {
{ head = tail = event;
head = tail = event; queue_flare.Fire();
queue_flare.Fire(); }
} else {
else tail->SetNext(event);
{ tail = event;
tail->SetNext(event); }
tail = event;
}
++event_mgr.num_events_queued; ++event_mgr.num_events_queued;
} }
void EventMgr::Dispatch(Event* event, bool no_remote) void EventMgr::Dispatch(Event* event, bool no_remote) {
{ current_src = event->Source();
current_src = event->Source(); current_aid = event->Analyzer();
current_aid = event->Analyzer(); current_ts = event->Time();
current_ts = event->Time(); event->Dispatch(no_remote);
event->Dispatch(no_remote); Unref(event);
Unref(event); }
}
void EventMgr::Drain() void EventMgr::Drain() {
{ if ( event_queue_flush_point )
if ( event_queue_flush_point ) Enqueue(event_queue_flush_point, Args{});
Enqueue(event_queue_flush_point, Args{});
detail::SegmentProfiler prof(detail::segment_logger, "draining-events"); detail::SegmentProfiler prof(detail::segment_logger, "draining-events");
PLUGIN_HOOK_VOID(HOOK_DRAIN_EVENTS, HookDrainEvents()); PLUGIN_HOOK_VOID(HOOK_DRAIN_EVENTS, HookDrainEvents());
draining = true; draining = true;
// Past Zeek versions drained as long as there events, including when // Past Zeek versions drained as long as there events, including when
// a handler queued new events during its execution. This could lead // a handler queued new events during its execution. This could lead
// to endless loops in case a handler kept triggering its own event. // to endless loops in case a handler kept triggering its own event.
// We now limit this to just a couple of rounds. We do more than // We now limit this to just a couple of rounds. We do more than
// just one round to make it less likely to break existing scripts // just one round to make it less likely to break existing scripts
// that expect the old behavior to trigger something quickly. // that expect the old behavior to trigger something quickly.
for ( int round = 0; head && round < 2; round++ ) for ( int round = 0; head && round < 2; round++ ) {
{ Event* current = head;
Event* current = head; head = nullptr;
head = nullptr; tail = nullptr;
tail = nullptr;
while ( current ) while ( current ) {
{ Event* next = current->NextEvent();
Event* next = current->NextEvent();
current_src = current->Source(); current_src = current->Source();
current_aid = current->Analyzer(); current_aid = current->Analyzer();
current_ts = current->Time(); current_ts = current->Time();
current->Dispatch(); current->Dispatch();
Unref(current); Unref(current);
++event_mgr.num_events_dispatched; ++event_mgr.num_events_dispatched;
current = next; current = next;
} }
} }
// Note: we might eventually need a general way to specify things to // Note: we might eventually need a general way to specify things to
// do after draining events. // do after draining events.
draining = false; draining = false;
// Make sure all of the triggers get processed every time the events // Make sure all of the triggers get processed every time the events
// drain. // drain.
detail::trigger_mgr->Process(); detail::trigger_mgr->Process();
} }
void EventMgr::Describe(ODesc* d) const void EventMgr::Describe(ODesc* d) const {
{ int n = 0;
int n = 0; Event* e;
Event* e; for ( e = head; e; e = e->NextEvent() )
for ( e = head; e; e = e->NextEvent() ) ++n;
++n;
d->AddCount(n); d->AddCount(n);
for ( e = head; e; e = e->NextEvent() ) for ( e = head; e; e = e->NextEvent() ) {
{ e->Describe(d);
e->Describe(d); d->NL();
d->NL(); }
} }
}
void EventMgr::Process() void EventMgr::Process() {
{ queue_flare.Extinguish();
queue_flare.Extinguish();
// While it semes like the most logical thing to do, we dont want // While it semes like the most logical thing to do, we dont want
// to call Drain() as part of this method. It will get called at // to call Drain() as part of this method. It will get called at
// the end of net_run after all of the sources have been processed // the end of net_run after all of the sources have been processed
// and had the opportunity to spawn new events. // and had the opportunity to spawn new events.
} }
void EventMgr::InitPostScript() void EventMgr::InitPostScript() {
{ iosource_mgr->Register(this, true, false);
iosource_mgr->Register(this, true, false); if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) )
if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) ) reporter->FatalError("Failed to register event manager FD with iosource_mgr");
reporter->FatalError("Failed to register event manager FD with iosource_mgr"); }
}
void EventMgr::InitPostFork() void EventMgr::InitPostFork() {
{ // Re-initialize the flare, closing and re-opening the underlying
// Re-initialize the flare, closing and re-opening the underlying // pipe FDs. This is needed so that each Zeek process in a supervisor
// pipe FDs. This is needed so that each Zeek process in a supervisor // setup has its own pipe instead of them all sharing a single pipe.
// setup has its own pipe instead of them all sharing a single pipe. queue_flare = zeek::detail::Flare{};
queue_flare = zeek::detail::Flare{}; }
}
} // namespace zeek } // namespace zeek

View file

@ -12,131 +12,124 @@
#include "zeek/analyzer/Analyzer.h" #include "zeek/analyzer/Analyzer.h"
#include "zeek/iosource/IOSource.h" #include "zeek/iosource/IOSource.h"
namespace zeek namespace zeek {
{
namespace run_state namespace run_state {
{
extern double network_time; extern double network_time;
} // namespace run_state } // namespace run_state
class EventMgr; class EventMgr;
class Event final : public Obj class Event final : public Obj {
{
public: public:
Event(const EventHandlerPtr& handler, zeek::Args args, Event(const EventHandlerPtr& handler, zeek::Args args, util::detail::SourceID src = util::detail::SOURCE_LOCAL,
util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time);
Obj* obj = nullptr, double ts = run_state::network_time);
void SetNext(Event* n) { next_event = n; } void SetNext(Event* n) { next_event = n; }
Event* NextEvent() const { return next_event; } Event* NextEvent() const { return next_event; }
util::detail::SourceID Source() const { return src; } util::detail::SourceID Source() const { return src; }
analyzer::ID Analyzer() const { return aid; } analyzer::ID Analyzer() const { return aid; }
EventHandlerPtr Handler() const { return handler; } EventHandlerPtr Handler() const { return handler; }
const zeek::Args& Args() const { return args; } const zeek::Args& Args() const { return args; }
double Time() const { return ts; } double Time() const { return ts; }
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
protected: protected:
friend class EventMgr; friend class EventMgr;
// This method is protected to make sure that everybody goes through // This method is protected to make sure that everybody goes through
// EventMgr::Dispatch(). // EventMgr::Dispatch().
void Dispatch(bool no_remote = false); void Dispatch(bool no_remote = false);
EventHandlerPtr handler; EventHandlerPtr handler;
zeek::Args args; zeek::Args args;
util::detail::SourceID src; util::detail::SourceID src;
analyzer::ID aid; analyzer::ID aid;
double ts; double ts;
Obj* obj; Obj* obj;
Event* next_event; Event* next_event;
}; };
class EventMgr final : public Obj, public iosource::IOSource class EventMgr final : public Obj, public iosource::IOSource {
{
public: public:
EventMgr(); EventMgr();
~EventMgr() override; ~EventMgr() override;
/** /**
* Adds an event to the queue. If no handler is found for the event * Adds an event to the queue. If no handler is found for the event
* when later going to call it, nothing happens except for having * when later going to call it, nothing happens except for having
* wasted a bit of time/resources, so callers may want to first check * wasted a bit of time/resources, so callers may want to first check
* if any handler/consumer exists before enqueuing an event. * if any handler/consumer exists before enqueuing an event.
* @param h reference to the event handler to later call. * @param h reference to the event handler to later call.
* @param vl the argument list to the event handler call. * @param vl the argument list to the event handler call.
* @param src indicates the origin of the event (local versus remote). * @param src indicates the origin of the event (local versus remote).
* @param aid identifies the protocol analyzer generating the event. * @param aid identifies the protocol analyzer generating the event.
* @param obj an arbitrary object to use as a "cookie" or just hold a * @param obj an arbitrary object to use as a "cookie" or just hold a
* reference to until dispatching the event. * reference to until dispatching the event.
* @param ts timestamp at which the event is intended to be executed * @param ts timestamp at which the event is intended to be executed
* (defaults to current network time). * (defaults to current network time).
*/ */
void Enqueue(const EventHandlerPtr& h, zeek::Args vl, void Enqueue(const EventHandlerPtr& h, zeek::Args vl, util::detail::SourceID src = util::detail::SOURCE_LOCAL,
util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time);
Obj* obj = nullptr, double ts = run_state::network_time);
/** /**
* A version of Enqueue() taking a variable number of arguments. * A version of Enqueue() taking a variable number of arguments.
*/ */
template <class... Args> template<class... Args>
std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>> std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>> Enqueue(
Enqueue(const EventHandlerPtr& h, Args&&... args) const EventHandlerPtr& h, Args&&... args) {
{ return Enqueue(h, zeek::Args{std::forward<Args>(args)...});
return Enqueue(h, zeek::Args{std::forward<Args>(args)...}); }
}
void Dispatch(Event* event, bool no_remote = false); void Dispatch(Event* event, bool no_remote = false);
void Drain(); void Drain();
bool IsDraining() const { return draining; } bool IsDraining() const { return draining; }
bool HasEvents() const { return head != nullptr; } bool HasEvents() const { return head != nullptr; }
// Returns the source ID of last raised event. // Returns the source ID of last raised event.
util::detail::SourceID CurrentSource() const { return current_src; } util::detail::SourceID CurrentSource() const { return current_src; }
// Returns the ID of the analyzer which raised the last event, or 0 if // Returns the ID of the analyzer which raised the last event, or 0 if
// non-analyzer event. // non-analyzer event.
analyzer::ID CurrentAnalyzer() const { return current_aid; } analyzer::ID CurrentAnalyzer() const { return current_aid; }
// Returns the timestamp of the last raised event. The timestamp reflects the network time // Returns the timestamp of the last raised event. The timestamp reflects the network time
// the event was intended to be executed. For scheduled events, this is the time the event // the event was intended to be executed. For scheduled events, this is the time the event
// was scheduled to. For any other event, this is the time when the event was created. // was scheduled to. For any other event, this is the time when the event was created.
double CurrentEventTime() const { return current_ts; } double CurrentEventTime() const { return current_ts; }
int Size() const { return num_events_queued - num_events_dispatched; } int Size() const { return num_events_queued - num_events_dispatched; }
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
double GetNextTimeout() override { return -1; } double GetNextTimeout() override { return -1; }
void Process() override; void Process() override;
const char* Tag() override { return "EventManager"; } const char* Tag() override { return "EventManager"; }
void InitPostScript(); void InitPostScript();
// Initialization to be done after a fork() happened. // Initialization to be done after a fork() happened.
void InitPostFork(); void InitPostFork();
uint64_t num_events_queued = 0; uint64_t num_events_queued = 0;
uint64_t num_events_dispatched = 0; uint64_t num_events_dispatched = 0;
protected: protected:
void QueueEvent(Event* event); void QueueEvent(Event* event);
Event* head; Event* head;
Event* tail; Event* tail;
util::detail::SourceID current_src; util::detail::SourceID current_src;
analyzer::ID current_aid; analyzer::ID current_aid;
double current_ts; double current_ts;
RecordVal* src_val; RecordVal* src_val;
bool draining; bool draining;
detail::Flare queue_flare; detail::Flare queue_flare;
}; };
extern EventMgr event_mgr; extern EventMgr event_mgr;
} // namespace zeek } // namespace zeek

View file

@ -11,127 +11,108 @@
#include "zeek/broker/Manager.h" #include "zeek/broker/Manager.h"
#include "zeek/telemetry/Manager.h" #include "zeek/telemetry/Manager.h"
namespace zeek namespace zeek {
{
EventHandler::EventHandler(std::string arg_name) EventHandler::EventHandler(std::string arg_name) {
{ name = std::move(arg_name);
name = std::move(arg_name); used = false;
used = false; error_handler = false;
error_handler = false; enabled = true;
enabled = true; generate_always = false;
generate_always = false; }
}
EventHandler::operator bool() const EventHandler::operator bool() const {
{ return enabled && ((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty());
return enabled && }
((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty());
}
const FuncTypePtr& EventHandler::GetType(bool check_export) const FuncTypePtr& EventHandler::GetType(bool check_export) {
{ if ( type )
if ( type ) return type;
return type;
const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, check_export);
check_export);
if ( ! id ) if ( ! id )
return FuncType::nil; return FuncType::nil;
if ( id->GetType()->Tag() != TYPE_FUNC ) if ( id->GetType()->Tag() != TYPE_FUNC )
return FuncType::nil; return FuncType::nil;
type = id->GetType<FuncType>(); type = id->GetType<FuncType>();
return type; return type;
} }
void EventHandler::SetFunc(FuncPtr f) void EventHandler::SetFunc(FuncPtr f) { local = std::move(f); }
{
local = std::move(f);
}
void EventHandler::Call(Args* vl, bool no_remote, double ts) void EventHandler::Call(Args* vl, bool no_remote, double ts) {
{ if ( ! call_count ) {
if ( ! call_count ) static auto eh_invocations_family =
{ telemetry_mgr->CounterFamily("zeek", "event-handler-invocations", {"name"},
static auto eh_invocations_family = telemetry_mgr->CounterFamily( "Number of times the given event handler was called", "1", true);
"zeek", "event-handler-invocations", {"name"},
"Number of times the given event handler was called", "1", true);
call_count = eh_invocations_family.GetOrAdd({{"name", name}}); call_count = eh_invocations_family.GetOrAdd({{"name", name}});
} }
call_count->Inc(); call_count->Inc();
if ( new_event ) if ( new_event )
NewEvent(vl); NewEvent(vl);
if ( ! no_remote ) if ( ! no_remote ) {
{ if ( ! auto_publish.empty() ) {
if ( ! auto_publish.empty() ) // Send event in form [name, xs...] where xs represent the arguments.
{ broker::vector xs;
// Send event in form [name, xs...] where xs represent the arguments. xs.reserve(vl->size());
broker::vector xs; bool valid_args = true;
xs.reserve(vl->size());
bool valid_args = true;
for ( auto i = 0u; i < vl->size(); ++i ) for ( auto i = 0u; i < vl->size(); ++i ) {
{ auto opt_data = Broker::detail::val_to_data((*vl)[i].get());
auto opt_data = Broker::detail::val_to_data((*vl)[i].get());
if ( opt_data ) if ( opt_data )
xs.emplace_back(std::move(*opt_data)); xs.emplace_back(std::move(*opt_data));
else else {
{ valid_args = false;
valid_args = false; auto_publish.clear();
auto_publish.clear(); reporter->Error("failed auto-remote event '%s', disabled", Name());
reporter->Error("failed auto-remote event '%s', disabled", Name()); break;
break; }
} }
}
if ( valid_args ) if ( valid_args ) {
{ for ( auto it = auto_publish.begin();; ) {
for ( auto it = auto_publish.begin();; ) const auto& topic = *it;
{ ++it;
const auto& topic = *it;
++it;
if ( it != auto_publish.end() ) if ( it != auto_publish.end() )
broker_mgr->PublishEvent(topic, Name(), xs, ts); broker_mgr->PublishEvent(topic, Name(), xs, ts);
else else {
{ broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts);
broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts); break;
break; }
} }
} }
} }
} }
}
if ( local ) if ( local )
// No try/catch here; we pass exceptions upstream. // No try/catch here; we pass exceptions upstream.
local->Invoke(vl); local->Invoke(vl);
} }
void EventHandler::NewEvent(Args* vl) void EventHandler::NewEvent(Args* vl) {
{ if ( ! new_event )
if ( ! new_event ) return;
return;
if ( this == new_event.Ptr() ) if ( this == new_event.Ptr() )
// new_event() is the one event we don't want to report. // new_event() is the one event we don't want to report.
return; return;
auto vargs = MakeCallArgumentVector(*vl, GetType()->Params()); auto vargs = MakeCallArgumentVector(*vl, GetType()->Params());
auto ev = new Event(new_event, { auto ev = new Event(new_event, {
make_intrusive<StringVal>(name), make_intrusive<StringVal>(name),
std::move(vargs), std::move(vargs),
}); });
event_mgr.Dispatch(ev); event_mgr.Dispatch(ev);
} }
} // namespace zeek } // namespace zeek

View file

@ -11,104 +11,95 @@
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
#include "zeek/telemetry/Counter.h" #include "zeek/telemetry/Counter.h"
namespace zeek namespace zeek {
{
namespace run_state namespace run_state {
{
extern double network_time; extern double network_time;
} // namespace run_state } // namespace run_state
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
class EventHandler class EventHandler {
{
public: public:
explicit EventHandler(std::string name); explicit EventHandler(std::string name);
const char* Name() const { return name.data(); } const char* Name() const { return name.data(); }
const FuncPtr& GetFunc() const { return local; } const FuncPtr& GetFunc() const { return local; }
const FuncTypePtr& GetType(bool check_export = true); const FuncTypePtr& GetType(bool check_export = true);
void SetFunc(FuncPtr f); void SetFunc(FuncPtr f);
void AutoPublish(std::string topic) { auto_publish.insert(std::move(topic)); } void AutoPublish(std::string topic) { auto_publish.insert(std::move(topic)); }
void AutoUnpublish(const std::string& topic) { auto_publish.erase(topic); } void AutoUnpublish(const std::string& topic) { auto_publish.erase(topic); }
void Call(zeek::Args* vl, bool no_remote = false, double ts = run_state::network_time); void Call(zeek::Args* vl, bool no_remote = false, double ts = run_state::network_time);
// Returns true if there is at least one local or remote handler. // Returns true if there is at least one local or remote handler.
explicit operator bool() const; explicit operator bool() const;
void SetUsed() { used = true; } void SetUsed() { used = true; }
bool Used() const { return used; } bool Used() const { return used; }
// Handlers marked as error handlers will not be called recursively to // Handlers marked as error handlers will not be called recursively to
// avoid infinite loops if they trigger a similar error themselves. // avoid infinite loops if they trigger a similar error themselves.
void SetErrorHandler() { error_handler = true; } void SetErrorHandler() { error_handler = true; }
bool ErrorHandler() const { return error_handler; } bool ErrorHandler() const { return error_handler; }
void SetEnable(bool arg_enable) { enabled = arg_enable; } void SetEnable(bool arg_enable) { enabled = arg_enable; }
// Flags the event as interesting even if there is no body defined. In // Flags the event as interesting even if there is no body defined. In
// particular, this will then still pass the event on to plugins. // particular, this will then still pass the event on to plugins.
void SetGenerateAlways(bool arg_generate_always = true) void SetGenerateAlways(bool arg_generate_always = true) { generate_always = arg_generate_always; }
{ bool GenerateAlways() const { return generate_always; }
generate_always = arg_generate_always;
}
bool GenerateAlways() const { return generate_always; }
uint64_t CallCount() const { return call_count ? call_count->Value() : 0; } uint64_t CallCount() const { return call_count ? call_count->Value() : 0; }
private: private:
void NewEvent(zeek::Args* vl); // Raise new_event() meta event. void NewEvent(zeek::Args* vl); // Raise new_event() meta event.
std::string name; std::string name;
FuncPtr local; FuncPtr local;
FuncTypePtr type; FuncTypePtr type;
bool used; // this handler is indeed used somewhere bool used; // this handler is indeed used somewhere
bool enabled; bool enabled;
bool error_handler; // this handler reports error messages. bool error_handler; // this handler reports error messages.
bool generate_always; bool generate_always;
// Initialize this lazy, so we don't expose metrics for 0 values. // Initialize this lazy, so we don't expose metrics for 0 values.
std::optional<zeek::telemetry::IntCounter> call_count; std::optional<zeek::telemetry::IntCounter> call_count;
std::unordered_set<std::string> auto_publish; std::unordered_set<std::string> auto_publish;
}; };
// Encapsulates a ptr to an event handler to overload the boolean operator. // Encapsulates a ptr to an event handler to overload the boolean operator.
class EventHandlerPtr class EventHandlerPtr {
{
public: public:
EventHandlerPtr(EventHandler* p = nullptr) { handler = p; } EventHandlerPtr(EventHandler* p = nullptr) { handler = p; }
EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; } EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; }
const EventHandlerPtr& operator=(EventHandler* p) const EventHandlerPtr& operator=(EventHandler* p) {
{ handler = p;
handler = p; return *this;
return *this; }
} const EventHandlerPtr& operator=(const EventHandlerPtr& h) {
const EventHandlerPtr& operator=(const EventHandlerPtr& h) handler = h.handler;
{ return *this;
handler = h.handler; }
return *this;
}
bool operator==(const EventHandlerPtr& h) const { return handler == h.handler; } bool operator==(const EventHandlerPtr& h) const { return handler == h.handler; }
EventHandler* Ptr() { return handler; } EventHandler* Ptr() { return handler; }
explicit operator bool() const { return handler && *handler; } explicit operator bool() const { return handler && *handler; }
EventHandler* operator->() { return handler; } EventHandler* operator->() { return handler; }
const EventHandler* operator->() const { return handler; } const EventHandler* operator->() const { return handler; }
private: private:
EventHandler* handler; EventHandler* handler;
}; };
} // namespace zeek } // namespace zeek

View file

@ -7,167 +7,144 @@
#include "zeek/RE.h" #include "zeek/RE.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek namespace zeek {
{
EventRegistry::EventRegistry() = default; EventRegistry::EventRegistry() = default;
EventRegistry::~EventRegistry() noexcept = default; EventRegistry::~EventRegistry() noexcept = default;
EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) {
{ // If there already is an entry in the registry, we have a
// If there already is an entry in the registry, we have a // local handler on the script layer.
// local handler on the script layer. EventHandler* h = event_registry->Lookup(name);
EventHandler* h = event_registry->Lookup(name);
if ( h ) if ( h ) {
{ if ( ! is_from_script )
if ( ! is_from_script ) not_only_from_script.insert(std::string(name));
not_only_from_script.insert(std::string(name));
h->SetUsed(); h->SetUsed();
return h; return h;
} }
h = new EventHandler(std::string(name)); h = new EventHandler(std::string(name));
event_registry->Register(h, is_from_script); event_registry->Register(h, is_from_script);
h->SetUsed(); h->SetUsed();
return h; return h;
} }
void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) {
{ std::string name = handler->Name();
std::string name = handler->Name();
handlers[name] = std::unique_ptr<EventHandler>(handler.Ptr()); handlers[name] = std::unique_ptr<EventHandler>(handler.Ptr());
if ( ! is_from_script ) if ( ! is_from_script )
not_only_from_script.insert(name); not_only_from_script.insert(name);
} }
EventHandler* EventRegistry::Lookup(std::string_view name) EventHandler* EventRegistry::Lookup(std::string_view name) {
{ auto it = handlers.find(name);
auto it = handlers.find(name); if ( it != handlers.end() )
if ( it != handlers.end() ) return it->second.get();
return it->second.get();
return nullptr; return nullptr;
} }
bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) {
{ return not_only_from_script.count(std::string(name)) > 0;
return not_only_from_script.count(std::string(name)) > 0; }
}
EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) {
{ string_list names;
string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{ EventHandler* v = entry.second.get();
EventHandler* v = entry.second.get(); if ( v->GetFunc() && pattern->MatchExactly(v->Name()) )
if ( v->GetFunc() && pattern->MatchExactly(v->Name()) ) names.push_back(entry.first);
names.push_back(entry.first); }
}
return names; return names;
} }
EventRegistry::string_list EventRegistry::UnusedHandlers() EventRegistry::string_list EventRegistry::UnusedHandlers() {
{ string_list names;
string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{ EventHandler* v = entry.second.get();
EventHandler* v = entry.second.get(); if ( v->GetFunc() && ! v->Used() )
if ( v->GetFunc() && ! v->Used() ) names.push_back(entry.first);
names.push_back(entry.first); }
}
return names; return names;
} }
EventRegistry::string_list EventRegistry::UsedHandlers() EventRegistry::string_list EventRegistry::UsedHandlers() {
{ string_list names;
string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{ EventHandler* v = entry.second.get();
EventHandler* v = entry.second.get(); if ( v->GetFunc() && v->Used() )
if ( v->GetFunc() && v->Used() ) names.push_back(entry.first);
names.push_back(entry.first); }
}
return names; return names;
} }
EventRegistry::string_list EventRegistry::AllHandlers() EventRegistry::string_list EventRegistry::AllHandlers() {
{ string_list names;
string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{ names.push_back(entry.first);
names.push_back(entry.first); }
}
return names; return names;
} }
void EventRegistry::PrintDebug() void EventRegistry::PrintDebug() {
{ for ( const auto& entry : handlers ) {
for ( const auto& entry : handlers ) EventHandler* v = entry.second.get();
{ fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), v->GetFunc() ? "local" : "no",
EventHandler* v = entry.second.get(); *v ? "active" : "not active");
fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), }
v->GetFunc() ? "local" : "no", *v ? "active" : "not active"); }
}
}
void EventRegistry::SetErrorHandler(std::string_view name) void EventRegistry::SetErrorHandler(std::string_view name) {
{ EventHandler* eh = Lookup(name);
EventHandler* eh = Lookup(name);
if ( eh ) if ( eh ) {
{ eh->SetErrorHandler();
eh->SetErrorHandler(); return;
return; }
}
reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", std::string(name).c_str());
std::string(name).c_str()); }
}
void EventRegistry::ActivateAllHandlers() void EventRegistry::ActivateAllHandlers() {
{ auto event_names = AllHandlers();
auto event_names = AllHandlers(); for ( const auto& name : event_names ) {
for ( const auto& name : event_names ) if ( auto event = Lookup(name) )
{ event->SetGenerateAlways();
if ( auto event = Lookup(name) ) }
event->SetGenerateAlways(); }
}
}
EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) {
{ auto key = std::pair{kind, std::string{name}};
auto key = std::pair{kind, std::string{name}}; if ( const auto& it = event_groups.find(key); it != event_groups.end() )
if ( const auto& it = event_groups.find(key); it != event_groups.end() ) return it->second;
return it->second;
auto group = std::make_shared<EventGroup>(kind, name); auto group = std::make_shared<EventGroup>(kind, name);
return event_groups.emplace(key, group).first->second; return event_groups.emplace(key, group).first->second;
} }
EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) {
{ auto key = std::pair{kind, std::string{name}};
auto key = std::pair{kind, std::string{name}}; if ( const auto& it = event_groups.find(key); it != event_groups.end() )
if ( const auto& it = event_groups.find(key); it != event_groups.end() ) return it->second;
return it->second;
return nullptr; return nullptr;
} }
EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind), name(name) { } EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind), name(name) {}
// Run through all ScriptFunc instances associated with this group and // Run through all ScriptFunc instances associated with this group and
// update their bodies after a group's enable/disable state has changed. // update their bodies after a group's enable/disable state has changed.
@ -177,50 +154,36 @@ EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind),
// EventGroup is private friend with Func, so fiddling with the bodies // EventGroup is private friend with Func, so fiddling with the bodies
// and private members works and keeps the logic out of Func and away // and private members works and keeps the logic out of Func and away
// from the public zeek:: namespace. // from the public zeek:: namespace.
void EventGroup::UpdateFuncBodies() void EventGroup::UpdateFuncBodies() {
{ static auto is_group_disabled = [](const auto& g) { return g->IsDisabled(); };
static auto is_group_disabled = [](const auto& g)
{
return g->IsDisabled();
};
for ( auto& func : funcs ) for ( auto& func : funcs ) {
{ for ( auto& b : func->bodies )
for ( auto& b : func->bodies ) b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled);
b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled);
static auto is_body_enabled = [](const auto& b) static auto is_body_enabled = [](const auto& b) { return ! b.disabled; };
{ func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(), is_body_enabled);
return ! b.disabled; }
}; }
func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(),
is_body_enabled);
}
}
void EventGroup::Enable() void EventGroup::Enable() {
{ if ( enabled )
if ( enabled ) return;
return;
enabled = true; enabled = true;
UpdateFuncBodies(); UpdateFuncBodies();
} }
void EventGroup::Disable() void EventGroup::Disable() {
{ if ( ! enabled )
if ( ! enabled ) return;
return;
enabled = false; enabled = false;
UpdateFuncBodies(); UpdateFuncBodies();
} }
void EventGroup::AddFunc(detail::ScriptFuncPtr f) void EventGroup::AddFunc(detail::ScriptFuncPtr f) { funcs.insert(f); }
{
funcs.insert(f);
}
} // namespace zeek } // namespace zeek

View file

@ -13,15 +13,13 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
// The different kinds of event groups that exist. // The different kinds of event groups that exist.
enum class EventGroupKind enum class EventGroupKind {
{ Attribute,
Attribute, Module,
Module, };
};
class EventGroup; class EventGroup;
class EventHandler; class EventHandler;
@ -30,89 +28,86 @@ class RE_Matcher;
using EventGroupPtr = std::shared_ptr<EventGroup>; using EventGroupPtr = std::shared_ptr<EventGroup>;
namespace detail namespace detail {
{
class ScriptFunc; class ScriptFunc;
using ScriptFuncPtr = zeek::IntrusivePtr<ScriptFunc>; using ScriptFuncPtr = zeek::IntrusivePtr<ScriptFunc>;
} } // namespace detail
// The registry keeps track of all events that we provide or handle. // The registry keeps track of all events that we provide or handle.
class EventRegistry final class EventRegistry final {
{
public: public:
EventRegistry(); EventRegistry();
~EventRegistry() noexcept; ~EventRegistry() noexcept;
/** /**
* Performs a lookup for an existing event handler and returns it * Performs a lookup for an existing event handler and returns it
* if one exists, or else creates one, registers it, and returns it. * if one exists, or else creates one, registers it, and returns it.
* @param name The name of the event handler to lookup/register. * @param name The name of the event handler to lookup/register.
* @param name Whether the registration is coming from a script element. * @param name Whether the registration is coming from a script element.
* @return The event handler. * @return The event handler.
*/ */
EventHandlerPtr Register(std::string_view name, bool is_from_script = false); EventHandlerPtr Register(std::string_view name, bool is_from_script = false);
void Register(EventHandlerPtr handler, bool is_from_script = false); void Register(EventHandlerPtr handler, bool is_from_script = false);
// Return nil if unknown. // Return nil if unknown.
EventHandler* Lookup(std::string_view name); EventHandler* Lookup(std::string_view name);
// True if the given event handler (1) exists, and (2) was registered // True if the given event handler (1) exists, and (2) was registered
// in a non-script context (even if perhaps also registered in a script // in a non-script context (even if perhaps also registered in a script
// context). // context).
bool NotOnlyRegisteredFromScript(std::string_view name); bool NotOnlyRegisteredFromScript(std::string_view name);
// Returns a list of all local handlers that match the given pattern. // Returns a list of all local handlers that match the given pattern.
// Passes ownership of list. // Passes ownership of list.
using string_list = std::vector<std::string>; using string_list = std::vector<std::string>;
string_list Match(RE_Matcher* pattern); string_list Match(RE_Matcher* pattern);
// Marks a handler as handling errors. Error handler will not be called // Marks a handler as handling errors. Error handler will not be called
// recursively to avoid infinite loops in case they trigger an error // recursively to avoid infinite loops in case they trigger an error
// themselves. // themselves.
void SetErrorHandler(std::string_view name); void SetErrorHandler(std::string_view name);
string_list UnusedHandlers(); string_list UnusedHandlers();
string_list UsedHandlers(); string_list UsedHandlers();
string_list AllHandlers(); string_list AllHandlers();
void PrintDebug(); void PrintDebug();
/** /**
* Marks all event handlers as active. * Marks all event handlers as active.
* *
* By default, zeek does not generate (raise) events that have not handled by * By default, zeek does not generate (raise) events that have not handled by
* any scripts. This means that these events will be invisible to a lot of other * any scripts. This means that these events will be invisible to a lot of other
* event handlers - and will not raise :zeek:id:`new_event`. Calling this * event handlers - and will not raise :zeek:id:`new_event`. Calling this
* function will cause all event handlers to be raised. This is likely only * function will cause all event handlers to be raised. This is likely only
* useful for debugging and fuzzing, and likely causes reduced performance. * useful for debugging and fuzzing, and likely causes reduced performance.
*/ */
void ActivateAllHandlers(); void ActivateAllHandlers();
/** /**
* Lookup or register a new event group. * Lookup or register a new event group.
* *
* @return Pointer to the group. * @return Pointer to the group.
*/ */
EventGroupPtr RegisterGroup(EventGroupKind kind, std::string_view name); EventGroupPtr RegisterGroup(EventGroupKind kind, std::string_view name);
/** /**
* Lookup an event group. * Lookup an event group.
* *
* @return Pointer to the group or nil if the group does not exist. * @return Pointer to the group or nil if the group does not exist.
*/ */
EventGroupPtr LookupGroup(EventGroupKind kind, std::string_view name); EventGroupPtr LookupGroup(EventGroupKind kind, std::string_view name);
private: private:
std::map<std::string, std::unique_ptr<EventHandler>, std::less<>> handlers; std::map<std::string, std::unique_ptr<EventHandler>, std::less<>> handlers;
// Tracks whether a given event handler was registered in a // Tracks whether a given event handler was registered in a
// non-script context. // non-script context.
std::unordered_set<std::string> not_only_from_script; std::unordered_set<std::string> not_only_from_script;
// Map event groups identified by kind and name to their instances. // Map event groups identified by kind and name to their instances.
std::map<std::pair<EventGroupKind, std::string>, std::shared_ptr<EventGroup>, std::less<>> std::map<std::pair<EventGroupKind, std::string>, std::shared_ptr<EventGroup>, std::less<>> event_groups;
event_groups; };
};
/** /**
* Event group. * Event group.
@ -135,45 +130,44 @@ private:
* bodies of the tracked ScriptFuncs and updates them to reflect the current * bodies of the tracked ScriptFuncs and updates them to reflect the current
* group state. * group state.
*/ */
class EventGroup final class EventGroup final {
{
public: public:
EventGroup(EventGroupKind kind, std::string_view name); EventGroup(EventGroupKind kind, std::string_view name);
~EventGroup() noexcept = default; ~EventGroup() noexcept = default;
EventGroup(const EventGroup& g) = delete; EventGroup(const EventGroup& g) = delete;
EventGroup& operator=(const EventGroup&) = delete; EventGroup& operator=(const EventGroup&) = delete;
/** /**
* Enable this event group and update all event handlers associated with it. * Enable this event group and update all event handlers associated with it.
*/ */
void Enable(); void Enable();
/** /**
* Disable this event group and update all event handlers associated with it. * Disable this event group and update all event handlers associated with it.
*/ */
void Disable(); void Disable();
/** /**
* @return True if this group is disabled else false. * @return True if this group is disabled else false.
*/ */
bool IsDisabled() { return ! enabled; } bool IsDisabled() { return ! enabled; }
/** /**
* Add a function to this group that may contain matching bodies. * Add a function to this group that may contain matching bodies.
* *
* @param f Pointer to the function to track. * @param f Pointer to the function to track.
*/ */
void AddFunc(detail::ScriptFuncPtr f); void AddFunc(detail::ScriptFuncPtr f);
private: private:
void UpdateFuncBodies(); void UpdateFuncBodies();
EventGroupKind kind; EventGroupKind kind;
std::string name; std::string name;
bool enabled = true; bool enabled = true;
std::unordered_set<detail::ScriptFuncPtr> funcs; std::unordered_set<detail::ScriptFuncPtr> funcs;
}; };
extern EventRegistry* event_registry; extern EventRegistry* event_registry;
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -5,37 +5,35 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
namespace zeek::detail namespace zeek::detail {
{
class ValTrace; class ValTrace;
class ValTraceMgr; class ValTraceMgr;
// Abstract class for capturing a single difference between two script-level // Abstract class for capturing a single difference between two script-level
// values. Includes notions of inserting, changing, or deleting a value. // values. Includes notions of inserting, changing, or deleting a value.
class ValDelta class ValDelta {
{
public: public:
ValDelta(const ValTrace* _vt) : vt(_vt) { } ValDelta(const ValTrace* _vt) : vt(_vt) {}
virtual ~ValDelta() { } virtual ~ValDelta() {}
// Return a string that performs the update operation, expressed // Return a string that performs the update operation, expressed
// as Zeek scripting. Does not include a terminating semicolon. // as Zeek scripting. Does not include a terminating semicolon.
virtual std::string Generate(ValTraceMgr* vtm) const; virtual std::string Generate(ValTraceMgr* vtm) const;
// Whether the generated string needs the affected value to // Whether the generated string needs the affected value to
// explicitly appear on the left-hand-side. Note that this // explicitly appear on the left-hand-side. Note that this
// might not be as a simple "LHS = RHS" assignment, but instead // might not be as a simple "LHS = RHS" assignment, but instead
// as "LHS$field = RHS" or "LHS[index] = RHS". // as "LHS$field = RHS" or "LHS[index] = RHS".
// //
// Returns false for generated strings like "delete LHS[index]". // Returns false for generated strings like "delete LHS[index]".
virtual bool NeedsLHS() const { return true; } virtual bool NeedsLHS() const { return true; }
const ValTrace* GetValTrace() const { return vt; } const ValTrace* GetValTrace() const { return vt; }
protected: protected:
const ValTrace* vt; const ValTrace* vt;
}; };
using DeltaVector = std::vector<std::unique_ptr<ValDelta>>; using DeltaVector = std::vector<std::unique_ptr<ValDelta>>;
@ -43,464 +41,426 @@ using DeltaVector = std::vector<std::unique_ptr<ValDelta>>;
// For non-aggregates, this is simply the Val object, but for aggregates // For non-aggregates, this is simply the Val object, but for aggregates
// it is (recursively) each of the sub-elements, in a manner that can then // it is (recursively) each of the sub-elements, in a manner that can then
// be readily compared against future instances. // be readily compared against future instances.
class ValTrace class ValTrace {
{
public: public:
ValTrace(const ValPtr& v); ValTrace(const ValPtr& v);
~ValTrace() = default; ~ValTrace() = default;
const ValPtr& GetVal() const { return v; } const ValPtr& GetVal() const { return v; }
const TypePtr& GetType() const { return t; } const TypePtr& GetType() const { return t; }
const auto& GetElems() const { return elems; } const auto& GetElems() const { return elems; }
// Returns true if this trace and the given one represent the // Returns true if this trace and the given one represent the
// same underlying value. Can involve subelement-by-subelement // same underlying value. Can involve subelement-by-subelement
// (recursive) comparisons. // (recursive) comparisons.
bool operator==(const ValTrace& vt) const; bool operator==(const ValTrace& vt) const;
bool operator!=(const ValTrace& vt) const { return ! ((*this) == vt); } bool operator!=(const ValTrace& vt) const { return ! ((*this) == vt); }
// Computes the deltas between a previous ValTrace and this one. // Computes the deltas between a previous ValTrace and this one.
// If "prev" is nil then we're creating this value from scratch // If "prev" is nil then we're creating this value from scratch
// (though if it's an aggregate, we may reuse existing values // (though if it's an aggregate, we may reuse existing values
// for some of its components). // for some of its components).
// //
// Returns the accumulated differences in "deltas". If on return // Returns the accumulated differences in "deltas". If on return
// nothing was added to "deltas" then the two ValTrace's are equivalent // nothing was added to "deltas" then the two ValTrace's are equivalent
// (no changes between them). // (no changes between them).
void ComputeDelta(const ValTrace* prev, DeltaVector& deltas) const; void ComputeDelta(const ValTrace* prev, DeltaVector& deltas) const;
private: private:
// Methods for tracing different types of aggregate values. // Methods for tracing different types of aggregate values.
void TraceList(const ListValPtr& lv); void TraceList(const ListValPtr& lv);
void TraceRecord(const RecordValPtr& rv); void TraceRecord(const RecordValPtr& rv);
void TraceTable(const TableValPtr& tv); void TraceTable(const TableValPtr& tv);
void TraceVector(const VectorValPtr& vv); void TraceVector(const VectorValPtr& vv);
// Predicates for comparing different types of aggregates for equality. // Predicates for comparing different types of aggregates for equality.
bool SameList(const ValTrace& vt) const; bool SameList(const ValTrace& vt) const;
bool SameRecord(const ValTrace& vt) const; bool SameRecord(const ValTrace& vt) const;
bool SameTable(const ValTrace& vt) const; bool SameTable(const ValTrace& vt) const;
bool SameVector(const ValTrace& vt) const; bool SameVector(const ValTrace& vt) const;
// Helper function that knows about the internal vector-of-subelements // Helper function that knows about the internal vector-of-subelements
// we use for aggregates. // we use for aggregates.
bool SameElems(const ValTrace& vt) const; bool SameElems(const ValTrace& vt) const;
// True if this value is a singleton and it's the same value as // True if this value is a singleton and it's the same value as
// represented in "vt". // represented in "vt".
bool SameSingleton(const ValTrace& vt) const; bool SameSingleton(const ValTrace& vt) const;
// Add to "deltas" the differences needed to turn a previous instance // Add to "deltas" the differences needed to turn a previous instance
// of the given type of aggregate to the current instance. // of the given type of aggregate to the current instance.
void ComputeRecordDelta(const ValTrace* prev, DeltaVector& deltas) const; void ComputeRecordDelta(const ValTrace* prev, DeltaVector& deltas) const;
void ComputeTableDelta(const ValTrace* prev, DeltaVector& deltas) const; void ComputeTableDelta(const ValTrace* prev, DeltaVector& deltas) const;
void ComputeVectorDelta(const ValTrace* prev, DeltaVector& deltas) const; void ComputeVectorDelta(const ValTrace* prev, DeltaVector& deltas) const;
// Holds sub-elements for aggregates. // Holds sub-elements for aggregates.
std::vector<std::shared_ptr<ValTrace>> elems; std::vector<std::shared_ptr<ValTrace>> elems;
// A parallel vector used for the yield values of tables. // A parallel vector used for the yield values of tables.
std::vector<std::shared_ptr<ValTrace>> elems2; std::vector<std::shared_ptr<ValTrace>> elems2;
ValPtr v; ValPtr v;
TypePtr t; // v's type, for convenience TypePtr t; // v's type, for convenience
}; };
// Captures the basic notion of a new, non-equivalent value being assigned. // Captures the basic notion of a new, non-equivalent value being assigned.
class DeltaReplaceValue : public ValDelta class DeltaReplaceValue : public ValDelta {
{
public: public:
DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) : ValDelta(_vt), new_val(std::move(_new_val)) {}
: ValDelta(_vt), new_val(std::move(_new_val))
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
ValPtr new_val; ValPtr new_val;
}; };
// Captures the notion of setting a record field. // Captures the notion of setting a record field.
class DeltaSetField : public ValDelta class DeltaSetField : public ValDelta {
{
public: public:
DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val) DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val)
: ValDelta(_vt), field(_field), new_val(std::move(_new_val)) : ValDelta(_vt), field(_field), new_val(std::move(_new_val)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
int field; int field;
ValPtr new_val; ValPtr new_val;
}; };
// Captures the notion of deleting a record field. // Captures the notion of deleting a record field.
class DeltaRemoveField : public ValDelta class DeltaRemoveField : public ValDelta {
{
public: public:
DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) { } DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
bool NeedsLHS() const override { return false; } bool NeedsLHS() const override { return false; }
private: private:
int field; int field;
}; };
// Captures the notion of creating a record from scratch. // Captures the notion of creating a record from scratch.
class DeltaRecordCreate : public ValDelta class DeltaRecordCreate : public ValDelta {
{
public: public:
DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of adding an element to a set. Use DeltaRemoveTableEntry to // Captures the notion of adding an element to a set. Use DeltaRemoveTableEntry to
// delete values. // delete values.
class DeltaSetSetEntry : public ValDelta class DeltaSetSetEntry : public ValDelta {
{
public: public:
DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) { } DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
bool NeedsLHS() const override { return false; } bool NeedsLHS() const override { return false; }
private: private:
ValPtr index; ValPtr index;
}; };
// Captures the notion of setting a table entry (which includes both changing // Captures the notion of setting a table entry (which includes both changing
// an existing one and adding a new one). Use DeltaRemoveTableEntry to // an existing one and adding a new one). Use DeltaRemoveTableEntry to
// delete values. // delete values.
class DeltaSetTableEntry : public ValDelta class DeltaSetTableEntry : public ValDelta {
{
public: public:
DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val) DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val)
: ValDelta(_vt), index(_index), new_val(std::move(_new_val)) : ValDelta(_vt), index(_index), new_val(std::move(_new_val)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
ValPtr index; ValPtr index;
ValPtr new_val; ValPtr new_val;
}; };
// Captures the notion of removing a table/set entry. // Captures the notion of removing a table/set entry.
class DeltaRemoveTableEntry : public ValDelta class DeltaRemoveTableEntry : public ValDelta {
{
public: public:
DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(std::move(_index)) {}
: ValDelta(_vt), index(std::move(_index))
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
bool NeedsLHS() const override { return false; } bool NeedsLHS() const override { return false; }
private: private:
ValPtr index; ValPtr index;
}; };
// Captures the notion of creating a set from scratch. // Captures the notion of creating a set from scratch.
class DeltaSetCreate : public ValDelta class DeltaSetCreate : public ValDelta {
{
public: public:
DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of creating a table from scratch. // Captures the notion of creating a table from scratch.
class DeltaTableCreate : public ValDelta class DeltaTableCreate : public ValDelta {
{
public: public:
DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of changing an element of a vector. // Captures the notion of changing an element of a vector.
class DeltaVectorSet : public ValDelta class DeltaVectorSet : public ValDelta {
{
public: public:
DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem) DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem)
: ValDelta(_vt), index(_index), elem(std::move(_elem)) : ValDelta(_vt), index(_index), elem(std::move(_elem)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
int index; int index;
ValPtr elem; ValPtr elem;
}; };
// Captures the notion of adding an entry to the end of a vector. // Captures the notion of adding an entry to the end of a vector.
class DeltaVectorAppend : public ValDelta class DeltaVectorAppend : public ValDelta {
{
public: public:
DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem) DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem)
: ValDelta(_vt), index(_index), elem(std::move(_elem)) : ValDelta(_vt), index(_index), elem(std::move(_elem)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
int index; int index;
ValPtr elem; ValPtr elem;
}; };
// Captures the notion of replacing a vector wholesale. // Captures the notion of replacing a vector wholesale.
class DeltaVectorCreate : public ValDelta class DeltaVectorCreate : public ValDelta {
{
public: public:
DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of creating a value with an unsupported type // Captures the notion of creating a value with an unsupported type
// (like "opaque"). // (like "opaque").
class DeltaUnsupportedCreate : public ValDelta class DeltaUnsupportedCreate : public ValDelta {
{
public: public:
DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Manages the changes to (or creation of) a variable used to represent // Manages the changes to (or creation of) a variable used to represent
// a value. // a value.
class DeltaGen class DeltaGen {
{
public: public:
DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def) DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def)
: val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), : val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), is_first_def(_is_first_def) {}
is_first_def(_is_first_def)
{
}
const ValPtr& GetVal() const { return val; } const ValPtr& GetVal() const { return val; }
const std::string& RHS() const { return rhs; } const std::string& RHS() const { return rhs; }
bool NeedsLHS() const { return needs_lhs; } bool NeedsLHS() const { return needs_lhs; }
bool IsFirstDef() const { return is_first_def; } bool IsFirstDef() const { return is_first_def; }
private: private:
ValPtr val; ValPtr val;
// The expression to set the variable to. // The expression to set the variable to.
std::string rhs; std::string rhs;
// Whether that expression needs the variable explicitly provides // Whether that expression needs the variable explicitly provides
// on the lefthand side. // on the lefthand side.
bool needs_lhs; bool needs_lhs;
// Whether this is the first definition of the variable (in which // Whether this is the first definition of the variable (in which
// case we also need to declare the variable). // case we also need to declare the variable).
bool is_first_def; bool is_first_def;
}; };
using DeltaGenVec = std::vector<DeltaGen>; using DeltaGenVec = std::vector<DeltaGen>;
// Tracks a single event. // Tracks a single event.
class EventTrace class EventTrace {
{
public: public:
// Constructed in terms of the associated script function, "network // Constructed in terms of the associated script function, "network
// time" when the event occurred, and the position of this event // time" when the event occurred, and the position of this event
// within all of those being traced. // within all of those being traced.
EventTrace(const ScriptFunc* _ev, double _nt, size_t event_num); EventTrace(const ScriptFunc* _ev, double _nt, size_t event_num);
// Sets a string representation of the arguments (values) being // Sets a string representation of the arguments (values) being
// passed to the event. // passed to the event.
void SetArgs(std::string _args) { args = std::move(_args); } void SetArgs(std::string _args) { args = std::move(_args); }
// Adds to the trace an update for the given value. // Adds to the trace an update for the given value.
void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) {
{ auto& d = is_post ? post_deltas : deltas;
auto& d = is_post ? post_deltas : deltas; d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def));
d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def)); }
}
// Initially we analyze events pre-execution. When this flag // Initially we analyze events pre-execution. When this flag
// is set, we switch to instead analyzing post-execution. The // is set, we switch to instead analyzing post-execution. The
// difference allows us to annotate the output with "# from script" // difference allows us to annotate the output with "# from script"
// comments that flag changes created by script execution rather // comments that flag changes created by script execution rather
// than event engine activity. // than event engine activity.
void SetDoingPost() { is_post = true; } void SetDoingPost() { is_post = true; }
const char* GetName() const { return name.c_str(); } const char* GetName() const { return name.c_str(); }
// Generates an internal event handler that sets up the values // Generates an internal event handler that sets up the values
// associated with the traced event, followed by queueing the traced // associated with the traced event, followed by queueing the traced
// event, and then queueing the successor internal event handler, // event, and then queueing the successor internal event handler,
// if any. // if any.
// //
// "predecessor", if non-nil, gives the event that came just before // "predecessor", if non-nil, gives the event that came just before
// this one (used for "# from script" annotations"). "successor", // this one (used for "# from script" annotations"). "successor",
// if not empty, gives the name of the successor internal event. // if not empty, gives the name of the successor internal event.
void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, std::string successor) const;
std::string successor) const;
private: private:
// "dvec" is either just our deltas, or the "post_deltas" of our // "dvec" is either just our deltas, or the "post_deltas" of our
// predecessor plus our deltas. // predecessor plus our deltas.
void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, int num_pre = 0) const;
int num_pre = 0) const;
const ScriptFunc* ev; const ScriptFunc* ev;
double nt; double nt;
bool is_post = false; bool is_post = false;
// The deltas needed to construct the values associated with this // The deltas needed to construct the values associated with this
// event prior to its execution. // event prior to its execution.
DeltaGenVec deltas; DeltaGenVec deltas;
// The deltas capturing any changes to the original values as induced // The deltas capturing any changes to the original values as induced
// by executing its event handlers. // by executing its event handlers.
DeltaGenVec post_deltas; DeltaGenVec post_deltas;
// The event's name and a string representation of its arguments. // The event's name and a string representation of its arguments.
std::string name; std::string name;
std::string args; std::string args;
}; };
// Manages all of the events and associated values seen during the execution. // Manages all of the events and associated values seen during the execution.
class ValTraceMgr class ValTraceMgr {
{
public: public:
// Invoked to trace a new event with the associated arguments. // Invoked to trace a new event with the associated arguments.
void TraceEventValues(std::shared_ptr<EventTrace> et, const zeek::Args* args); void TraceEventValues(std::shared_ptr<EventTrace> et, const zeek::Args* args);
// Invoked when the current event finishes execution. The arguments // Invoked when the current event finishes execution. The arguments
// are again provided, for convenience so we don't have to remember // are again provided, for convenience so we don't have to remember
// them from the previous method. // them from the previous method.
void FinishCurrentEvent(const zeek::Args* args); void FinishCurrentEvent(const zeek::Args* args);
// Returns the name of the script variable associated with the // Returns the name of the script variable associated with the
// given value. // given value.
const std::string& ValName(const ValPtr& v); const std::string& ValName(const ValPtr& v);
const std::string& ValName(const ValTrace* vt) { return ValName(vt->GetVal()); } const std::string& ValName(const ValTrace* vt) { return ValName(vt->GetVal()); }
// Returns true if the script variable associated with the given value // Returns true if the script variable associated with the given value
// needs to be global (because it's used across multiple events). // needs to be global (because it's used across multiple events).
bool IsGlobal(const ValPtr& v) const { return globals.count(v.get()) > 0; } bool IsGlobal(const ValPtr& v) const { return globals.count(v.get()) > 0; }
// Returns or sets the "base time" from which eligible times are // Returns or sets the "base time" from which eligible times are
// transformed into offsets rather than maintained as absolute // transformed into offsets rather than maintained as absolute
// values. // values.
double GetBaseTime() const { return base_time; } double GetBaseTime() const { return base_time; }
void SetBaseTime(double bt) { base_time = bt; } void SetBaseTime(double bt) { base_time = bt; }
// Returns a Zeek script representation of the given "time" value. // Returns a Zeek script representation of the given "time" value.
// This might be relative to base_time or might be absolute. // This might be relative to base_time or might be absolute.
std::string TimeConstant(double t); std::string TimeConstant(double t);
// Returns the array of per-type-tag constants. // Returns the array of per-type-tag constants.
const auto& GetConstants() const { return constants; } const auto& GetConstants() const { return constants; }
private: private:
// Traces the given value, which we may-or-may-not have seen before. // Traces the given value, which we may-or-may-not have seen before.
void AddVal(ValPtr v); void AddVal(ValPtr v);
// Creates a new value, associating a script variable with it. // Creates a new value, associating a script variable with it.
void NewVal(ValPtr v); void NewVal(ValPtr v);
// Called when the given value is used in an expression that sets // Called when the given value is used in an expression that sets
// or updates another value. This lets us track which values are // or updates another value. This lets us track which values are
// used across multiple events, and thus need to be global. // used across multiple events, and thus need to be global.
void ValUsed(const ValPtr& v); void ValUsed(const ValPtr& v);
// Compares the two value traces to build up deltas capturing // Compares the two value traces to build up deltas capturing
// the difference between the previous one and the current one. // the difference between the previous one and the current one.
void AssessChange(const ValTrace* vt, const ValTrace* prev_vt); void AssessChange(const ValTrace* vt, const ValTrace* prev_vt);
// Create and track a script variable associated with the given value. // Create and track a script variable associated with the given value.
void TrackVar(const Val* vt); void TrackVar(const Val* vt);
// Generates a name for a value. // Generates a name for a value.
std::string GenValName(const ValPtr& v); std::string GenValName(const ValPtr& v);
// True if the given value is an unspecified (and empty set, // True if the given value is an unspecified (and empty set,
// table, or vector appearing as a constant rather than an // table, or vector appearing as a constant rather than an
// already-typed value). // already-typed value).
bool IsUnspecifiedAggregate(const ValPtr& v) const; bool IsUnspecifiedAggregate(const ValPtr& v) const;
// True if the given value has an unsupported type. // True if the given value has an unsupported type.
bool IsUnsupported(const Val* v) const; bool IsUnsupported(const Val* v) const;
// Maps values to their associated traces. // Maps values to their associated traces.
std::unordered_map<const Val*, std::shared_ptr<ValTrace>> val_map; std::unordered_map<const Val*, std::shared_ptr<ValTrace>> val_map;
// Maps values to the "names" we associated with them. For simple // Maps values to the "names" we associated with them. For simple
// values, the name is just a Zeek script constant. For aggregates, // values, the name is just a Zeek script constant. For aggregates,
// it's a dedicated script variable. // it's a dedicated script variable.
std::unordered_map<const Val*, std::string> val_names; std::unordered_map<const Val*, std::string> val_names;
int num_vars = 0; // the number of dedicated script variables int num_vars = 0; // the number of dedicated script variables
// Tracks which values we've processed up through the preceding event. // Tracks which values we've processed up through the preceding event.
// Any re-use we then see for the current event (via a ValUsed() call) // Any re-use we then see for the current event (via a ValUsed() call)
// then tells us that the value is used across events, and thus its // then tells us that the value is used across events, and thus its
// associated script variable needs to be global. // associated script variable needs to be global.
std::unordered_set<const Val*> processed_vals; std::unordered_set<const Val*> processed_vals;
// Tracks which values have associated script variables that need // Tracks which values have associated script variables that need
// to be global. // to be global.
std::unordered_set<const Val*> globals; std::unordered_set<const Val*> globals;
// Indexed by type tag, stores an ordered set of all of the distinct // Indexed by type tag, stores an ordered set of all of the distinct
// representations of constants of that type. // representations of constants of that type.
std::array<std::set<std::string>, NUM_TYPES> constants; std::array<std::set<std::string>, NUM_TYPES> constants;
// If non-zero, then we've established a "base time" and will report // If non-zero, then we've established a "base time" and will report
// time constants as offsets from it (when reasonable, i.e., no // time constants as offsets from it (when reasonable, i.e., no
// negative offsets, and base_time can't be too close to 0.0). // negative offsets, and base_time can't be too close to 0.0).
double base_time = 0.0; double base_time = 0.0;
// The event we're currently tracing. // The event we're currently tracing.
std::shared_ptr<EventTrace> curr_ev; std::shared_ptr<EventTrace> curr_ev;
// Hang on to values we're tracking to make sure the pointers don't // Hang on to values we're tracking to make sure the pointers don't
// get reused when the main use of the value ends. // get reused when the main use of the value ends.
std::vector<ValPtr> vals; std::vector<ValPtr> vals;
}; };
// Manages tracing of all of the events seen during execution, including // Manages tracing of all of the events seen during execution, including
// the final generation of the trace script. // the final generation of the trace script.
class EventTraceMgr class EventTraceMgr {
{
public: public:
EventTraceMgr(const std::string& trace_file); EventTraceMgr(const std::string& trace_file);
~EventTraceMgr(); ~EventTraceMgr();
// Called at the beginning of invoking an event's handlers. // Called at the beginning of invoking an event's handlers.
void StartEvent(const ScriptFunc* ev, const zeek::Args* args); void StartEvent(const ScriptFunc* ev, const zeek::Args* args);
// Called after finishing with invoking an event's handlers. // Called after finishing with invoking an event's handlers.
void EndEvent(const ScriptFunc* ev, const zeek::Args* args); void EndEvent(const ScriptFunc* ev, const zeek::Args* args);
// Used to track events generated at script-level. // Used to track events generated at script-level.
void ScriptEventQueued(const EventHandlerPtr& h); void ScriptEventQueued(const EventHandlerPtr& h);
private: private:
FILE* f = nullptr; FILE* f = nullptr;
ValTraceMgr vtm; ValTraceMgr vtm;
// All of the events we've traced so far. // All of the events we've traced so far.
std::vector<std::shared_ptr<EventTrace>> events; std::vector<std::shared_ptr<EventTrace>> events;
// The names of all of the script events that have been generated. // The names of all of the script events that have been generated.
std::unordered_set<std::string> script_events; std::unordered_set<std::string> script_events;
}; };
// If non-nil then we're doing event tracing. // If non-nil then we're doing event tracing.
extern std::unique_ptr<EventTraceMgr> etm; extern std::unique_ptr<EventTraceMgr> etm;
} // namespace zeek::detail } // namespace zeek::detail

File diff suppressed because it is too large Load diff

2321
src/Expr.h

File diff suppressed because it is too large Load diff

View file

@ -31,331 +31,296 @@
#include "zeek/Type.h" #include "zeek/Type.h"
#include "zeek/Var.h" #include "zeek/Var.h"
namespace zeek namespace zeek {
{
std::list<std::pair<std::string, File*>> File::open_files; std::list<std::pair<std::string, File*>> File::open_files;
// Maximizes the number of open file descriptors. // Maximizes the number of open file descriptors.
static void maximize_num_fds() static void maximize_num_fds() {
{ struct rlimit rl;
struct rlimit rl; if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 )
if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 ) reporter->FatalError("maximize_num_fds(): getrlimit failed");
reporter->FatalError("maximize_num_fds(): getrlimit failed");
if ( rl.rlim_max == RLIM_INFINITY ) if ( rl.rlim_max == RLIM_INFINITY ) {
{ // Don't try raising the current limit.
// Don't try raising the current limit. return;
return; }
}
// See if we can raise the current to the maximum. // See if we can raise the current to the maximum.
rl.rlim_cur = rl.rlim_max; rl.rlim_cur = rl.rlim_max;
if ( setrlimit(RLIMIT_NOFILE, &rl) < 0 ) if ( setrlimit(RLIMIT_NOFILE, &rl) < 0 )
reporter->FatalError("maximize_num_fds(): setrlimit failed"); reporter->FatalError("maximize_num_fds(): setrlimit failed");
} }
File::File(FILE* arg_f) File::File(FILE* arg_f) {
{ Init();
Init(); f = arg_f;
f = arg_f; name = access = nullptr;
name = access = nullptr; t = base_type(TYPE_STRING);
t = base_type(TYPE_STRING); is_open = (f != nullptr);
is_open = (f != nullptr); }
}
File::File(FILE* arg_f, const char* arg_name, const char* arg_access) File::File(FILE* arg_f, const char* arg_name, const char* arg_access) {
{ Init();
Init(); f = arg_f;
f = arg_f; name = util::copy_string(arg_name);
name = util::copy_string(arg_name); access = util::copy_string(arg_access);
access = util::copy_string(arg_access); t = base_type(TYPE_STRING);
t = base_type(TYPE_STRING); is_open = (f != nullptr);
is_open = (f != nullptr); }
}
File::File(const char* arg_name, const char* arg_access) File::File(const char* arg_name, const char* arg_access) {
{ Init();
Init(); f = nullptr;
f = nullptr; name = util::copy_string(arg_name);
name = util::copy_string(arg_name); access = util::copy_string(arg_access);
access = util::copy_string(arg_access); t = base_type(TYPE_STRING);
t = base_type(TYPE_STRING);
if ( util::streq(name, "/dev/stdin") ) if ( util::streq(name, "/dev/stdin") )
f = stdin; f = stdin;
else if ( util::streq(name, "/dev/stdout") ) else if ( util::streq(name, "/dev/stdout") )
f = stdout; f = stdout;
else if ( util::streq(name, "/dev/stderr") ) else if ( util::streq(name, "/dev/stderr") )
f = stderr; f = stderr;
if ( f ) if ( f )
is_open = true; is_open = true;
else if ( ! Open() ) else if ( ! Open() ) {
{ reporter->Error("cannot open %s: %s", name, strerror(errno));
reporter->Error("cannot open %s: %s", name, strerror(errno)); is_open = false;
is_open = false; }
} }
}
const char* File::Name() const const char* File::Name() const {
{ if ( name )
if ( name ) return name;
return name;
if ( f == stdin ) if ( f == stdin )
return "/dev/stdin"; return "/dev/stdin";
if ( f == stdout ) if ( f == stdout )
return "/dev/stdout"; return "/dev/stdout";
if ( f == stderr ) if ( f == stderr )
return "/dev/stderr"; return "/dev/stderr";
return nullptr; return nullptr;
} }
bool File::Open(FILE* file, const char* mode) bool File::Open(FILE* file, const char* mode) {
{ static bool fds_maximized = false;
static bool fds_maximized = false; open_time = run_state::network_time ? run_state::network_time : util::current_time();
open_time = run_state::network_time ? run_state::network_time : util::current_time();
if ( ! fds_maximized ) if ( ! fds_maximized ) {
{ // Haven't initialized yet.
// Haven't initialized yet. maximize_num_fds();
maximize_num_fds(); fds_maximized = true;
fds_maximized = true; }
}
f = file; f = file;
if ( ! f ) if ( ! f ) {
{ if ( ! mode )
if ( ! mode ) f = fopen(name, access);
f = fopen(name, access); else
else f = fopen(name, mode);
f = fopen(name, mode); }
}
SetBuf(buffered); SetBuf(buffered);
if ( ! f ) if ( ! f ) {
{ is_open = false;
is_open = false; return false;
return false; }
}
is_open = true; is_open = true;
open_files.emplace_back(name, this); open_files.emplace_back(name, this);
RaiseOpenEvent(); RaiseOpenEvent();
return true; return true;
} }
File::~File() File::~File() {
{ Close();
Close(); Unref(attrs);
Unref(attrs);
delete[] name; delete[] name;
delete[] access; delete[] access;
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
heap_checker->UnIgnoreObject(this); heap_checker->UnIgnoreObject(this);
#endif #endif
} }
void File::Init() void File::Init() {
{ open_time = 0;
open_time = 0; is_open = false;
is_open = false; attrs = nullptr;
attrs = nullptr; buffered = true;
buffered = true; raw_output = false;
raw_output = false;
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
heap_checker->IgnoreObject(this); heap_checker->IgnoreObject(this);
#endif #endif
} }
FILE* File::FileHandle() FILE* File::FileHandle() { return f; }
{
return f;
}
FILE* File::Seek(long new_position) FILE* File::Seek(long new_position) {
{ if ( ! FileHandle() )
if ( ! FileHandle() ) return nullptr;
return nullptr;
if ( fseek(f, new_position, SEEK_SET) < 0 ) if ( fseek(f, new_position, SEEK_SET) < 0 )
reporter->Error("seek failed"); reporter->Error("seek failed");
return f; return f;
} }
void File::SetBuf(bool arg_buffered) void File::SetBuf(bool arg_buffered) {
{ if ( ! f )
if ( ! f ) return;
return;
if ( util::detail::setvbuf(f, NULL, arg_buffered ? _IOFBF : _IOLBF, 0) != 0 ) if ( util::detail::setvbuf(f, NULL, arg_buffered ? _IOFBF : _IOLBF, 0) != 0 )
reporter->Error("setvbuf failed"); reporter->Error("setvbuf failed");
buffered = arg_buffered; buffered = arg_buffered;
} }
bool File::Close() bool File::Close() {
{ if ( ! is_open )
if ( ! is_open ) return true;
return true;
// Do not close stdin/stdout/stderr. // Do not close stdin/stdout/stderr.
if ( f == stdin || f == stdout || f == stderr ) if ( f == stdin || f == stdout || f == stderr )
return false; return false;
if ( ! f ) if ( ! f )
return false; return false;
fclose(f); fclose(f);
f = nullptr; f = nullptr;
open_time = 0; open_time = 0;
is_open = false; is_open = false;
Unlink(); Unlink();
return true; return true;
} }
void File::Unlink() void File::Unlink() {
{ for ( auto it = open_files.begin(); it != open_files.end(); ++it ) {
for ( auto it = open_files.begin(); it != open_files.end(); ++it ) if ( (*it).second == this ) {
{ open_files.erase(it);
if ( (*it).second == this ) return;
{ }
open_files.erase(it); }
return; }
}
}
}
void File::Describe(ODesc* d) const void File::Describe(ODesc* d) const {
{ d->AddSP("file");
d->AddSP("file");
if ( name ) if ( name ) {
{ d->Add("\"");
d->Add("\""); d->Add(name);
d->Add(name); d->AddSP("\"");
d->AddSP("\""); }
}
d->AddSP("of"); d->AddSP("of");
if ( t ) if ( t )
t->Describe(d); t->Describe(d);
else else
d->Add("(no type)"); d->Add("(no type)");
} }
void File::SetAttrs(detail::Attributes* arg_attrs) void File::SetAttrs(detail::Attributes* arg_attrs) {
{ if ( ! arg_attrs )
if ( ! arg_attrs ) return;
return;
attrs = arg_attrs; attrs = arg_attrs;
Ref(attrs); Ref(attrs);
if ( attrs->Find(detail::ATTR_RAW_OUTPUT) ) if ( attrs->Find(detail::ATTR_RAW_OUTPUT) )
EnableRawOutput(); EnableRawOutput();
} }
RecordVal* File::Rotate() RecordVal* File::Rotate() {
{ if ( ! is_open )
if ( ! is_open ) return nullptr;
return nullptr;
// Do not rotate stdin/stdout/stderr. // Do not rotate stdin/stdout/stderr.
if ( f == stdin || f == stdout || f == stderr ) if ( f == stdin || f == stdout || f == stderr )
return nullptr; return nullptr;
static auto rotate_info = id::find_type<RecordType>("rotate_info"); static auto rotate_info = id::find_type<RecordType>("rotate_info");
auto* info = new RecordVal(rotate_info); auto* info = new RecordVal(rotate_info);
FILE* newf = util::detail::rotate_file(name, info); FILE* newf = util::detail::rotate_file(name, info);
if ( ! newf ) if ( ! newf ) {
{ Unref(info);
Unref(info); return nullptr;
return nullptr; }
}
info->AssignTime(2, open_time); info->AssignTime(2, open_time);
Unlink(); Unlink();
fclose(f); fclose(f);
f = nullptr; f = nullptr;
Open(newf); Open(newf);
return info; return info;
} }
void File::CloseOpenFiles() void File::CloseOpenFiles() {
{ auto it = open_files.begin();
auto it = open_files.begin(); while ( it != open_files.end() ) {
while ( it != open_files.end() ) auto el = it++;
{ (*el).second->Close();
auto el = it++; }
(*el).second->Close(); }
}
}
bool File::Write(const char* data, int len) bool File::Write(const char* data, int len) {
{ if ( ! is_open )
if ( ! is_open ) return false;
return false;
if ( ! len ) if ( ! len )
len = strlen(data); len = strlen(data);
if ( fwrite(data, len, 1, f) < 1 ) if ( fwrite(data, len, 1, f) < 1 )
return false; return false;
return true; return true;
} }
void File::RaiseOpenEvent() void File::RaiseOpenEvent() {
{ if ( ! ::file_opened )
if ( ! ::file_opened ) return;
return;
FilePtr bf{NewRef{}, this}; FilePtr bf{NewRef{}, this};
auto* event = new Event(::file_opened, {make_intrusive<FileVal>(std::move(bf))}); auto* event = new Event(::file_opened, {make_intrusive<FileVal>(std::move(bf))});
event_mgr.Dispatch(event, true); event_mgr.Dispatch(event, true);
} }
double File::Size() double File::Size() {
{ fflush(f);
fflush(f); struct stat s;
struct stat s; if ( fstat(fileno(f), &s) < 0 ) {
if ( fstat(fileno(f), &s) < 0 ) reporter->Error("can't stat fd for %s: %s", name, strerror(errno));
{ return 0;
reporter->Error("can't stat fd for %s: %s", name, strerror(errno)); }
return 0;
}
return s.st_size; return s.st_size;
} }
FilePtr File::Get(const char* name) FilePtr File::Get(const char* name) {
{ for ( const auto& el : open_files )
for ( const auto& el : open_files ) if ( el.first == name )
if ( el.first == name ) return {NewRef{}, el.second};
return {NewRef{}, el.second};
return make_intrusive<File>(name, "w"); return make_intrusive<File>(name, "w");
} }
} // namespace zeek } // namespace zeek

View file

@ -10,18 +10,16 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
namespace detail namespace detail {
{
class PrintStmt; class PrintStmt;
class Attributes; class Attributes;
extern void do_print_stmt(const std::vector<ValPtr>& vals); extern void do_print_stmt(const std::vector<ValPtr>& vals);
} // namespace detail; } // namespace detail
class RecordVal; class RecordVal;
class Type; class Type;
@ -30,94 +28,93 @@ using TypePtr = IntrusivePtr<Type>;
class File; class File;
using FilePtr = IntrusivePtr<File>; using FilePtr = IntrusivePtr<File>;
class File final : public Obj class File final : public Obj {
{
public: public:
explicit File(FILE* arg_f); explicit File(FILE* arg_f);
File(FILE* arg_f, const char* filename, const char* access); File(FILE* arg_f, const char* filename, const char* access);
File(const char* filename, const char* access); File(const char* filename, const char* access);
~File() override; ~File() override;
const char* Name() const; const char* Name() const;
// Returns false if an error occurred. // Returns false if an error occurred.
bool Write(const char* data, int len = 0); bool Write(const char* data, int len = 0);
void Flush() { fflush(f); } void Flush() { fflush(f); }
FILE* Seek(long position); // seek to absolute position FILE* Seek(long position); // seek to absolute position
void SetBuf(bool buffered); // false=line buffered, true=fully buffered void SetBuf(bool buffered); // false=line buffered, true=fully buffered
const TypePtr& GetType() const { return t; } const TypePtr& GetType() const { return t; }
// Whether the file is open in a general sense; it might // Whether the file is open in a general sense; it might
// not be open as a Unix file due to our management of // not be open as a Unix file due to our management of
// a finite number of FDs. // a finite number of FDs.
bool IsOpen() const { return is_open; } bool IsOpen() const { return is_open; }
// Returns true if the close made sense, false if it was already // Returns true if the close made sense, false if it was already
// closed, not active, or whatever. // closed, not active, or whatever.
bool Close(); bool Close();
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
// Rotates the logfile. Returns rotate_info. // Rotates the logfile. Returns rotate_info.
RecordVal* Rotate(); RecordVal* Rotate();
// Set &raw_output attribute. // Set &raw_output attribute.
void SetAttrs(detail::Attributes* attrs); void SetAttrs(detail::Attributes* attrs);
// Returns the current size of the file, after fresh stat'ing. // Returns the current size of the file, after fresh stat'ing.
double Size(); double Size();
// Close all files which are currently open. // Close all files which are currently open.
static void CloseOpenFiles(); static void CloseOpenFiles();
// Get the file with the given name, opening it if it doesn't yet exist. // Get the file with the given name, opening it if it doesn't yet exist.
static FilePtr Get(const char* name); static FilePtr Get(const char* name);
void EnableRawOutput() { raw_output = true; } void EnableRawOutput() { raw_output = true; }
bool IsRawOutput() const { return raw_output; } bool IsRawOutput() const { return raw_output; }
protected: protected:
friend void detail::do_print_stmt(const std::vector<ValPtr>& vals); friend void detail::do_print_stmt(const std::vector<ValPtr>& vals);
File() { Init(); } File() { Init(); }
void Init(); void Init();
/** /**
* If file is given, it's an open file to use already. * If file is given, it's an open file to use already.
* If file is not given and mode is, the filename will be opened with that * If file is not given and mode is, the filename will be opened with that
* access mode. * access mode.
*/ */
bool Open(FILE* f = nullptr, const char* mode = nullptr); bool Open(FILE* f = nullptr, const char* mode = nullptr);
void Unlink(); void Unlink();
// Returns nil if the file is not active, was in error, etc. // Returns nil if the file is not active, was in error, etc.
// (Protected because we do not want anyone to write directly // (Protected because we do not want anyone to write directly
// to the file, but the PrintStmt friend uses this to check whether // to the file, but the PrintStmt friend uses this to check whether
// it's really stdout.) // it's really stdout.)
FILE* FileHandle(); FILE* FileHandle();
// Raises a file_opened event. // Raises a file_opened event.
void RaiseOpenEvent(); void RaiseOpenEvent();
FILE* f = nullptr; FILE* f = nullptr;
TypePtr t; TypePtr t;
char* name = nullptr; char* name = nullptr;
char* access = nullptr; char* access = nullptr;
detail::Attributes* attrs = nullptr; detail::Attributes* attrs = nullptr;
double open_time = 0.0; double open_time = 0.0;
bool is_open = false; // whether the file is open in a general sense bool is_open = false; // whether the file is open in a general sense
bool buffered = false; bool buffered = false;
bool raw_output = false; bool raw_output = false;
static constexpr int MIN_BUFFER_SIZE = 1024; static constexpr int MIN_BUFFER_SIZE = 1024;
private: private:
static std::list<std::pair<std::string, File*>> open_files; static std::list<std::pair<std::string, File*>> open_files;
}; };
} // namespace zeek } // namespace zeek

View file

@ -12,148 +12,134 @@
#include <winsock2.h> #include <winsock2.h>
#define fatalError(...) \ #define fatalError(...) \
do \ do { \
{ \ if ( reporter ) \
if ( reporter ) \ reporter->FatalError(__VA_ARGS__); \
reporter->FatalError(__VA_ARGS__); \ else { \
else \ fprintf(stderr, __VA_ARGS__); \
{ \ fprintf(stderr, "\n"); \
fprintf(stderr, __VA_ARGS__); \ _exit(1); \
fprintf(stderr, "\n"); \ } \
_exit(1); \ } while ( 0 )
} \
} while ( 0 )
#endif #endif
namespace zeek::detail namespace zeek::detail {
{
Flare::Flare() Flare::Flare()
#ifndef _MSC_VER #ifndef _MSC_VER
: pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) : pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) {
{ }
}
#else #else
{ {
WSADATA wsaData; WSADATA wsaData;
if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 ) if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 )
fatalError("WSAStartup failure: %d", WSAGetLastError()); fatalError("WSAStartup failure: %d", WSAGetLastError());
recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT); if ( recvfd == (int)INVALID_SOCKET )
if ( recvfd == (int)INVALID_SOCKET ) fatalError("WSASocket failure: %d", WSAGetLastError());
fatalError("WSASocket failure: %d", WSAGetLastError()); sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, if ( sendfd == (int)INVALID_SOCKET )
WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT); fatalError("WSASocket failure: %d", WSAGetLastError());
if ( sendfd == (int)INVALID_SOCKET )
fatalError("WSASocket failure: %d", WSAGetLastError());
sockaddr_in sa; sockaddr_in sa;
memset(&sa, 0, sizeof(sa)); memset(&sa, 0, sizeof(sa));
sa.sin_family = AF_INET; sa.sin_family = AF_INET;
sa.sin_addr.s_addr = inet_addr("127.0.0.1"); sa.sin_addr.s_addr = inet_addr("127.0.0.1");
if ( bind(recvfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR ) if ( bind(recvfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR )
fatalError("bind failure: %d", WSAGetLastError()); fatalError("bind failure: %d", WSAGetLastError());
int salen = sizeof(sa); int salen = sizeof(sa);
if ( getsockname(recvfd, (sockaddr*)&sa, &salen) == SOCKET_ERROR ) if ( getsockname(recvfd, (sockaddr*)&sa, &salen) == SOCKET_ERROR )
fatalError("getsockname failure: %d", WSAGetLastError()); fatalError("getsockname failure: %d", WSAGetLastError());
if ( connect(sendfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR ) if ( connect(sendfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR )
fatalError("connect failure: %d", WSAGetLastError()); fatalError("connect failure: %d", WSAGetLastError());
} }
#endif #endif
[[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) [[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) {
{ if ( signal_safe )
if ( signal_safe ) abort();
abort();
char buf[256]; char buf[256];
util::zeek_strerror_r(errno, buf, sizeof(buf)); util::zeek_strerror_r(errno, buf, sizeof(buf));
if ( reporter ) if ( reporter )
reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf); reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf);
else else {
{ fprintf(stderr, "unexpected pipe %s failure: %s", which, buf);
fprintf(stderr, "unexpected pipe %s failure: %s", which, buf); abort();
abort(); }
} }
}
void Flare::Fire(bool signal_safe) void Flare::Fire(bool signal_safe) {
{ char tmp = 0;
char tmp = 0;
for ( ;; ) for ( ;; ) {
{
#ifndef _MSC_VER #ifndef _MSC_VER
int n = write(pipe.WriteFD(), &tmp, 1); int n = write(pipe.WriteFD(), &tmp, 1);
#else #else
int n = send(sendfd, &tmp, 1, 0); int n = send(sendfd, &tmp, 1, 0);
#endif #endif
if ( n > 0 ) if ( n > 0 )
// Success -- wrote a byte to pipe. // Success -- wrote a byte to pipe.
break; break;
if ( n < 0 ) if ( n < 0 ) {
{
#ifdef _MSC_VER #ifdef _MSC_VER
errno = WSAGetLastError(); errno = WSAGetLastError();
bad_pipe_op("send", signal_safe); bad_pipe_op("send", signal_safe);
#endif #endif
if ( errno == EAGAIN ) if ( errno == EAGAIN )
// Success: pipe is full and just need at least one byte in it. // Success: pipe is full and just need at least one byte in it.
break; break;
if ( errno == EINTR ) if ( errno == EINTR )
// Interrupted: try again. // Interrupted: try again.
continue; continue;
bad_pipe_op("write", signal_safe); bad_pipe_op("write", signal_safe);
} }
// No error, but didn't write a byte: try again. // No error, but didn't write a byte: try again.
} }
} }
int Flare::Extinguish(bool signal_safe) int Flare::Extinguish(bool signal_safe) {
{ int rval = 0;
int rval = 0; char tmp[256];
char tmp[256];
for ( ;; ) for ( ;; ) {
{
#ifndef _MSC_VER #ifndef _MSC_VER
int n = read(pipe.ReadFD(), &tmp, sizeof(tmp)); int n = read(pipe.ReadFD(), &tmp, sizeof(tmp));
#else #else
int n = recv(recvfd, tmp, sizeof(tmp), 0); int n = recv(recvfd, tmp, sizeof(tmp), 0);
#endif #endif
if ( n >= 0 ) if ( n >= 0 ) {
{ rval += n;
rval += n; // Pipe may not be empty yet: try again.
// Pipe may not be empty yet: try again. continue;
continue; }
}
#ifdef _MSC_VER #ifdef _MSC_VER
if ( WSAGetLastError() == WSAEWOULDBLOCK ) if ( WSAGetLastError() == WSAEWOULDBLOCK )
break; break;
errno = WSAGetLastError(); errno = WSAGetLastError();
bad_pipe_op("recv", signal_safe); bad_pipe_op("recv", signal_safe);
#endif #endif
if ( errno == EAGAIN ) if ( errno == EAGAIN )
// Success: pipe is now empty. // Success: pipe is now empty.
break; break;
if ( errno == EINTR ) if ( errno == EINTR )
// Interrupted: try again. // Interrupted: try again.
continue; continue;
bad_pipe_op("read", signal_safe); bad_pipe_op("read", signal_safe);
} }
return rval; return rval;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,57 +6,55 @@
#include "Pipe.h" #include "Pipe.h"
#endif #endif
namespace zeek::detail namespace zeek::detail {
{
class Flare class Flare {
{
public: public:
/** /**
* Create a flare object that can be used to signal a "ready" status via * Create a flare object that can be used to signal a "ready" status via
* a file descriptor that may be integrated with select(), poll(), etc. * a file descriptor that may be integrated with select(), poll(), etc.
* Not thread-safe, but that should only require Fire()/Extinguish() calls * Not thread-safe, but that should only require Fire()/Extinguish() calls
* to be made mutually exclusive (across all copies of a Flare). * to be made mutually exclusive (across all copies of a Flare).
*/ */
Flare(); Flare();
/** /**
* @return a file descriptor that will become ready if the flare has been * @return a file descriptor that will become ready if the flare has been
* Fire()'d and not yet Extinguished()'d. * Fire()'d and not yet Extinguished()'d.
*/ */
int FD() const int FD() const
#ifndef _MSC_VER #ifndef _MSC_VER
{ {
return pipe.ReadFD(); return pipe.ReadFD();
} }
#else #else
{ {
return recvfd; return recvfd;
} }
#endif #endif
/** /**
* Put the object in the "ready" state. * Put the object in the "ready" state.
* @param signal_safe whether to skip error-reporting functionality that * @param signal_safe whether to skip error-reporting functionality that
* is not async-signal-safe (errors still abort the process regardless) * is not async-signal-safe (errors still abort the process regardless)
*/ */
void Fire(bool signal_safe = false); void Fire(bool signal_safe = false);
/** /**
* Take the object out of the "ready" state. * Take the object out of the "ready" state.
* @param signal_safe whether to skip error-reporting functionality that * @param signal_safe whether to skip error-reporting functionality that
* is not async-signal-safe (errors still abort the process regardless) * is not async-signal-safe (errors still abort the process regardless)
* @return number of bytes read from the pipe, corresponds to the number * @return number of bytes read from the pipe, corresponds to the number
* of times Fire() was called. * of times Fire() was called.
*/ */
int Extinguish(bool signal_safe = false); int Extinguish(bool signal_safe = false);
private: private:
#ifndef _MSC_VER #ifndef _MSC_VER
Pipe pipe; Pipe pipe;
#else #else
int sendfd, recvfd; int sendfd, recvfd;
#endif #endif
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -14,371 +14,328 @@
constexpr uint32_t MIN_ACCEPTABLE_FRAG_SIZE = 64; constexpr uint32_t MIN_ACCEPTABLE_FRAG_SIZE = 64;
constexpr uint32_t MAX_ACCEPTABLE_FRAG_SIZE = 64000; constexpr uint32_t MAX_ACCEPTABLE_FRAG_SIZE = 64000;
namespace zeek::detail namespace zeek::detail {
{
FragTimer::~FragTimer() {
FragTimer::~FragTimer() if ( f )
{ f->ClearTimer();
if ( f ) }
f->ClearTimer();
} void FragTimer::Dispatch(double t, bool /* is_expire */) {
if ( f )
void FragTimer::Dispatch(double t, bool /* is_expire */) f->Expire(t);
{ else
if ( f ) reporter->InternalWarning("fragment timer dispatched w/o reassembler");
f->Expire(t); }
else
reporter->InternalWarning("fragment timer dispatched w/o reassembler"); FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt,
} const FragReassemblerKey& k, double t)
: Reassembler(0, REASSEM_FRAG) {
FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<IP_Hdr>& ip, s = arg_s;
const u_char* pkt, const FragReassemblerKey& k, double t) key = k;
: Reassembler(0, REASSEM_FRAG)
{ const struct ip* ip4 = ip->IP4_Hdr();
s = arg_s; if ( ip4 ) {
key = k; proto_hdr_len = ip->HdrLen();
proto_hdr = new u_char[64]; // max IP header + slop
const struct ip* ip4 = ip->IP4_Hdr(); // Don't do a structure copy - need to pick up options, too.
if ( ip4 ) memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len);
{ }
proto_hdr_len = ip->HdrLen(); else {
proto_hdr = new u_char[64]; // max IP header + slop proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header
// Don't do a structure copy - need to pick up options, too. proto_hdr = new u_char[proto_hdr_len];
memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len); memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len);
} }
else
{ reassembled_pkt = nullptr;
proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header frag_size = 0; // flag meaning "not known"
proto_hdr = new u_char[proto_hdr_len]; next_proto = ip->NextProto();
memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len);
} if ( frag_timeout != 0.0 ) {
expire_timer = new FragTimer(this, t + frag_timeout);
reassembled_pkt = nullptr; timer_mgr->Add(expire_timer);
frag_size = 0; // flag meaning "not known" }
next_proto = ip->NextProto(); else
expire_timer = nullptr;
if ( frag_timeout != 0.0 )
{ AddFragment(t, ip, pkt);
expire_timer = new FragTimer(this, t + frag_timeout); }
timer_mgr->Add(expire_timer);
} FragReassembler::~FragReassembler() {
else DeleteTimer();
expire_timer = nullptr; delete[] proto_hdr;
}
AddFragment(t, ip, pkt);
} void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) {
const struct ip* ip4 = ip->IP4_Hdr();
FragReassembler::~FragReassembler()
{ if ( ip4 ) {
DeleteTimer(); if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p || ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl )
delete[] proto_hdr; // || ip4->ip_tos != proto_hdr->ip_tos
} // don't check TOS, there's at least one stack that actually
// uses different values, and it's hard to see an associated
void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) // attack.
{ s->Weird("fragment_protocol_inconsistency", ip.get());
const struct ip* ip4 = ip->IP4_Hdr(); }
else {
if ( ip4 ) if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len )
{ s->Weird("fragment_protocol_inconsistency", ip.get());
if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p || // TODO: more detailed unfrag header consistency checks?
ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl ) }
// || ip4->ip_tos != proto_hdr->ip_tos
// don't check TOS, there's at least one stack that actually if ( ip->DF() )
// uses different values, and it's hard to see an associated // Linux MTU discovery for UDP can do this, for example.
// attack. s->Weird("fragment_with_DF", ip.get());
s->Weird("fragment_protocol_inconsistency", ip.get());
} uint16_t offset = ip->FragOffset();
else uint32_t len = ip->TotalLen();
{ uint16_t hdr_len = ip->HdrLen();
if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len )
s->Weird("fragment_protocol_inconsistency", ip.get()); if ( len < hdr_len ) {
// TODO: more detailed unfrag header consistency checks? s->Weird("fragment_protocol_inconsistency", ip.get());
} return;
}
if ( ip->DF() )
// Linux MTU discovery for UDP can do this, for example. uint64_t upper_seq = offset + len - hdr_len;
s->Weird("fragment_with_DF", ip.get());
if ( ! offset )
uint16_t offset = ip->FragOffset(); // Make sure to use the first fragment header's next field.
uint32_t len = ip->TotalLen(); next_proto = ip->NextProto();
uint16_t hdr_len = ip->HdrLen();
if ( ! ip->MF() ) {
if ( len < hdr_len ) // Last fragment.
{ if ( frag_size == 0 )
s->Weird("fragment_protocol_inconsistency", ip.get()); frag_size = upper_seq;
return;
} else if ( upper_seq != frag_size ) {
s->Weird("fragment_size_inconsistency", ip.get());
uint64_t upper_seq = offset + len - hdr_len;
if ( upper_seq > frag_size )
if ( ! offset ) frag_size = upper_seq;
// Make sure to use the first fragment header's next field. }
next_proto = ip->NextProto(); }
if ( ! ip->MF() ) else if ( len < MIN_ACCEPTABLE_FRAG_SIZE )
{ s->Weird("excessively_small_fragment", ip.get());
// Last fragment.
if ( frag_size == 0 ) if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE )
frag_size = upper_seq; s->Weird("excessively_large_fragment", ip.get());
else if ( upper_seq != frag_size ) if ( frag_size && upper_seq > frag_size ) {
{ // This can happen if we receive a fragment that's *not*
s->Weird("fragment_size_inconsistency", ip.get()); // the last fragment, but still imputes a size that's
// larger than the size we derived from a previously-seen
if ( upper_seq > frag_size ) // "last fragment".
frag_size = upper_seq;
} s->Weird("fragment_size_inconsistency", ip.get());
} frag_size = upper_seq;
}
else if ( len < MIN_ACCEPTABLE_FRAG_SIZE )
s->Weird("excessively_small_fragment", ip.get()); // Do we need to check for consistent options? That's tricky
// for things like LSRR that get modified in route.
if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE )
s->Weird("excessively_large_fragment", ip.get()); // Remove header.
pkt += hdr_len;
if ( frag_size && upper_seq > frag_size ) len -= hdr_len;
{
// This can happen if we receive a fragment that's *not* NewBlock(run_state::network_time, offset, len, pkt);
// the last fragment, but still imputes a size that's }
// larger than the size we derived from a previously-seen
// "last fragment". void FragReassembler::Weird(const char* name) const {
unsigned int version = ((const ip*)proto_hdr)->ip_v;
s->Weird("fragment_size_inconsistency", ip.get());
frag_size = upper_seq; if ( version == 4 ) {
} IP_Hdr hdr((const ip*)proto_hdr, false);
s->Weird(name, &hdr);
// Do we need to check for consistent options? That's tricky }
// for things like LSRR that get modified in route.
else if ( version == 6 ) {
// Remove header. IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len);
pkt += hdr_len; s->Weird(name, &hdr);
len -= hdr_len; }
NewBlock(run_state::network_time, offset, len, pkt); else {
} reporter->InternalWarning("Unexpected IP version in FragReassembler");
reporter->Weird(name);
void FragReassembler::Weird(const char* name) const }
{ }
unsigned int version = ((const ip*)proto_hdr)->ip_v;
void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n) {
if ( version == 4 ) if ( memcmp((const void*)b1, (const void*)b2, n) )
{ Weird("fragment_inconsistency");
IP_Hdr hdr((const ip*)proto_hdr, false); else
s->Weird(name, &hdr); Weird("fragment_overlap");
} }
else if ( version == 6 ) void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) {
{ auto it = block_list.Begin();
IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len);
s->Weird(name, &hdr); if ( it->second.seq > 0 || ! frag_size )
} // For sure don't have it all yet.
return;
else
{ auto next = std::next(it);
reporter->InternalWarning("Unexpected IP version in FragReassembler");
reporter->Weird(name); // We might have it all - look for contiguous all the way.
} while ( next != block_list.End() ) {
} if ( it->second.upper != next->second.seq )
break;
void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n)
{ ++it;
if ( memcmp((const void*)b1, (const void*)b2, n) ) ++next;
Weird("fragment_inconsistency"); }
else
Weird("fragment_overlap"); const auto& last = block_list.LastBlock();
}
if ( next != block_list.End() ) {
void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) // We have a hole.
{ if ( it->second.upper >= frag_size ) {
auto it = block_list.Begin(); // We're stuck. The point where we stopped is
// contiguous up through the expected end of
if ( it->second.seq > 0 || ! frag_size ) // the fragment, but there's more stuff still
// For sure don't have it all yet. // beyond it, which is not contiguous. This
return; // can happen for benign reasons when we're
// intermingling parts of two fragmented packets.
auto next = std::next(it); Weird("fragment_size_inconsistency");
// We might have it all - look for contiguous all the way. // We decide to analyze the contiguous portion now.
while ( next != block_list.End() ) // Extend the fragment up through the end of what
{ // we have.
if ( it->second.upper != next->second.seq ) frag_size = it->second.upper;
break; }
else
++it; return;
++next; }
}
else if ( last.upper > frag_size ) {
const auto& last = block_list.LastBlock(); Weird("fragment_size_inconsistency");
frag_size = last.upper;
if ( next != block_list.End() ) }
{
// We have a hole. else if ( last.upper < frag_size )
if ( it->second.upper >= frag_size ) // Missing the tail.
{ return;
// We're stuck. The point where we stopped is
// contiguous up through the expected end of // We have it all. Compute the expected size of the fragment.
// the fragment, but there's more stuff still uint64_t n = proto_hdr_len + frag_size;
// beyond it, which is not contiguous. This
// can happen for benign reasons when we're // It's possible that we have blocks associated with this fragment
// intermingling parts of two fragmented packets. // that exceed this size, if we saw MF fragments (which don't lead
Weird("fragment_size_inconsistency"); // to us setting frag_size) that went beyond the size indicated by
// the final, non-MF fragment. This can happen for benign reasons
// We decide to analyze the contiguous portion now. // due to intermingling of fragments from an older datagram with those
// Extend the fragment up through the end of what // for a more recent one.
// we have.
frag_size = it->second.upper; u_char* pkt = new u_char[n];
} memcpy((void*)pkt, (const void*)proto_hdr, proto_hdr_len);
else
return; u_char* pkt_start = pkt;
}
pkt += proto_hdr_len;
else if ( last.upper > frag_size )
{ for ( it = block_list.Begin(); it != block_list.End(); ++it ) {
Weird("fragment_size_inconsistency"); const auto& b = it->second;
frag_size = last.upper;
} if ( it != block_list.Begin() ) {
const auto& prev = std::prev(it)->second;
else if ( last.upper < frag_size )
// Missing the tail. // If we're above a hole, stop. This can happen because
return; // the logic above regarding a hole that's above the
// expected fragment size.
// We have it all. Compute the expected size of the fragment. if ( prev.upper < b.seq )
uint64_t n = proto_hdr_len + frag_size; break;
}
// It's possible that we have blocks associated with this fragment
// that exceed this size, if we saw MF fragments (which don't lead if ( b.upper > n ) {
// to us setting frag_size) that went beyond the size indicated by reporter->InternalWarning("bad fragment reassembly");
// the final, non-MF fragment. This can happen for benign reasons DeleteTimer();
// due to intermingling of fragments from an older datagram with those Expire(run_state::network_time);
// for a more recent one. delete[] pkt_start;
return;
u_char* pkt = new u_char[n]; }
memcpy((void*)pkt, (const void*)proto_hdr, proto_hdr_len);
memcpy(&pkt[b.seq], b.block, b.upper - b.seq);
u_char* pkt_start = pkt; }
pkt += proto_hdr_len; reassembled_pkt.reset();
for ( it = block_list.Begin(); it != block_list.End(); ++it ) unsigned int version = ((const struct ip*)pkt_start)->ip_v;
{
const auto& b = it->second; if ( version == 4 ) {
struct ip* reassem4 = (struct ip*)pkt_start;
if ( it != block_list.Begin() ) reassem4->ip_len = htons(frag_size + proto_hdr_len);
{ reassembled_pkt = std::make_shared<IP_Hdr>(reassem4, true, true);
const auto& prev = std::prev(it)->second; DeleteTimer();
}
// If we're above a hole, stop. This can happen because
// the logic above regarding a hole that's above the else if ( version == 6 ) {
// expected fragment size. struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start;
if ( prev.upper < b.seq ) reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40);
break; const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n);
} reassembled_pkt = std::make_shared<IP_Hdr>(reassem6, true, n, chain, true);
DeleteTimer();
if ( b.upper > n ) }
{
reporter->InternalWarning("bad fragment reassembly"); else {
DeleteTimer(); reporter->InternalWarning("bad IP version in fragment reassembly: %d", version);
Expire(run_state::network_time); delete[] pkt_start;
delete[] pkt_start; }
return; }
}
void FragReassembler::Expire(double t) {
memcpy(&pkt[b.seq], b.block, b.upper - b.seq); block_list.Clear();
} expire_timer->ClearReassembler();
expire_timer = nullptr; // timer manager will delete it
reassembled_pkt.reset();
fragment_mgr->Remove(this);
unsigned int version = ((const struct ip*)pkt_start)->ip_v; }
if ( version == 4 ) void FragReassembler::DeleteTimer() {
{ if ( expire_timer ) {
struct ip* reassem4 = (struct ip*)pkt_start; expire_timer->ClearReassembler();
reassem4->ip_len = htons(frag_size + proto_hdr_len); timer_mgr->Cancel(expire_timer);
reassembled_pkt = std::make_shared<IP_Hdr>(reassem4, true, true); expire_timer = nullptr; // timer manager will delete it
DeleteTimer(); }
} }
else if ( version == 6 ) FragmentManager::~FragmentManager() { Clear(); }
{
struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start; FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) {
reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40); uint32_t frag_id = ip->ID();
const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n); FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id);
reassembled_pkt = std::make_shared<IP_Hdr>(reassem6, true, n, chain, true);
DeleteTimer(); FragReassembler* f = nullptr;
} auto it = fragments.find(key);
if ( it != fragments.end() )
else f = it->second;
{
reporter->InternalWarning("bad IP version in fragment reassembly: %d", version); if ( ! f ) {
delete[] pkt_start; f = new FragReassembler(session_mgr, ip, pkt, key, t);
} fragments[key] = f;
} if ( fragments.size() > max_fragments )
max_fragments = fragments.size();
void FragReassembler::Expire(double t) return f;
{ }
block_list.Clear();
expire_timer->ClearReassembler(); f->AddFragment(t, ip, pkt);
expire_timer = nullptr; // timer manager will delete it return f;
}
fragment_mgr->Remove(this);
} void FragmentManager::Clear() {
for ( const auto& entry : fragments )
void FragReassembler::DeleteTimer() Unref(entry.second);
{
if ( expire_timer ) fragments.clear();
{ }
expire_timer->ClearReassembler();
timer_mgr->Cancel(expire_timer); void FragmentManager::Remove(detail::FragReassembler* f) {
expire_timer = nullptr; // timer manager will delete it if ( ! f )
} return;
}
if ( fragments.erase(f->Key()) == 0 )
FragmentManager::~FragmentManager() reporter->InternalWarning("fragment reassembler not in dict");
{
Clear(); Unref(f);
} }
FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, } // namespace zeek::detail
const u_char* pkt)
{
uint32_t frag_id = ip->ID();
FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id);
FragReassembler* f = nullptr;
auto it = fragments.find(key);
if ( it != fragments.end() )
f = it->second;
if ( ! f )
{
f = new FragReassembler(session_mgr, ip, pkt, key, t);
fragments[key] = f;
if ( fragments.size() > max_fragments )
max_fragments = fragments.size();
return f;
}
f->AddFragment(t, ip, pkt);
return f;
}
void FragmentManager::Clear()
{
for ( const auto& entry : fragments )
Unref(entry.second);
fragments.clear();
}
void FragmentManager::Remove(detail::FragReassembler* f)
{
if ( ! f )
return;
if ( fragments.erase(f->Key()) == 0 )
reporter->InternalWarning("fragment reassembler not in dict");
Unref(f);
}
} // namespace zeek::detail

View file

@ -10,102 +10,95 @@
#include "zeek/Timer.h" #include "zeek/Timer.h"
#include "zeek/util.h" // for zeek_uint_t #include "zeek/util.h" // for zeek_uint_t
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
namespace session namespace session {
{
class Manager; class Manager;
} }
namespace detail namespace detail {
{
class FragReassembler; class FragReassembler;
class FragTimer; class FragTimer;
using FragReassemblerKey = std::tuple<IPAddr, IPAddr, zeek_uint_t>; using FragReassemblerKey = std::tuple<IPAddr, IPAddr, zeek_uint_t>;
class FragReassembler : public Reassembler class FragReassembler : public Reassembler {
{
public: public:
FragReassembler(session::Manager* s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt, FragReassembler(session::Manager* s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt,
const FragReassemblerKey& k, double t); const FragReassemblerKey& k, double t);
~FragReassembler() override; ~FragReassembler() override;
void AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt); void AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt);
void Expire(double t); void Expire(double t);
void DeleteTimer(); void DeleteTimer();
void ClearTimer() { expire_timer = nullptr; } void ClearTimer() { expire_timer = nullptr; }
std::shared_ptr<IP_Hdr> ReassembledPkt() { return std::move(reassembled_pkt); } std::shared_ptr<IP_Hdr> ReassembledPkt() { return std::move(reassembled_pkt); }
const FragReassemblerKey& Key() const { return key; } const FragReassemblerKey& Key() const { return key; }
protected: protected:
void BlockInserted(DataBlockMap::const_iterator it) override; void BlockInserted(DataBlockMap::const_iterator it) override;
void Overlap(const u_char* b1, const u_char* b2, uint64_t n) override; void Overlap(const u_char* b1, const u_char* b2, uint64_t n) override;
void Weird(const char* name) const; void Weird(const char* name) const;
u_char* proto_hdr; u_char* proto_hdr;
std::shared_ptr<IP_Hdr> reassembled_pkt; std::shared_ptr<IP_Hdr> reassembled_pkt;
session::Manager* s; session::Manager* s;
uint64_t frag_size; // size of fully reassembled fragment uint64_t frag_size; // size of fully reassembled fragment
FragReassemblerKey key; FragReassemblerKey key;
uint16_t next_proto; // first IPv6 fragment header's next proto field uint16_t next_proto; // first IPv6 fragment header's next proto field
uint16_t proto_hdr_len; uint16_t proto_hdr_len;
FragTimer* expire_timer; FragTimer* expire_timer;
}; };
class FragTimer final : public Timer class FragTimer final : public Timer {
{
public: public:
FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; } FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; }
~FragTimer() override; ~FragTimer() override;
void Dispatch(double t, bool is_expire) override; void Dispatch(double t, bool is_expire) override;
// Break the association between this timer and its creator. // Break the association between this timer and its creator.
void ClearReassembler() { f = nullptr; } void ClearReassembler() { f = nullptr; }
protected: protected:
FragReassembler* f; FragReassembler* f;
}; };
class FragmentManager class FragmentManager {
{
public: public:
FragmentManager() = default; FragmentManager() = default;
~FragmentManager(); ~FragmentManager();
FragReassembler* NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt); FragReassembler* NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt);
void Clear(); void Clear();
void Remove(detail::FragReassembler* f); void Remove(detail::FragReassembler* f);
size_t Size() const { return fragments.size(); } size_t Size() const { return fragments.size(); }
size_t MaxFragments() const { return max_fragments; } size_t MaxFragments() const { return max_fragments; }
private: private:
using FragmentMap = std::map<detail::FragReassemblerKey, detail::FragReassembler*>; using FragmentMap = std::map<detail::FragReassemblerKey, detail::FragReassembler*>;
FragmentMap fragments; FragmentMap fragments;
size_t max_fragments = 0; size_t max_fragments = 0;
}; };
extern FragmentManager* fragment_mgr; extern FragmentManager* fragment_mgr;
class FragReassemblerTracker class FragReassemblerTracker {
{
public: public:
FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) { } FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) {}
~FragReassemblerTracker() { fragment_mgr->Remove(frag_reassembler); } ~FragReassemblerTracker() { fragment_mgr->Remove(frag_reassembler); }
private: private:
FragReassembler* frag_reassembler; FragReassembler* frag_reassembler;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -13,211 +13,185 @@
std::vector<zeek::detail::Frame*> g_frame_stack; std::vector<zeek::detail::Frame*> g_frame_stack;
namespace zeek::detail namespace zeek::detail {
{
Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) {
{ size = arg_size;
size = arg_size; frame = std::make_unique<Element[]>(size);
frame = std::make_unique<Element[]>(size); function = func;
function = func; func_args = fn_args;
func_args = fn_args;
next_stmt = nullptr; next_stmt = nullptr;
break_before_next_stmt = false; break_before_next_stmt = false;
break_on_return = false; break_on_return = false;
call = nullptr; call = nullptr;
delayed = false; delayed = false;
// We could Ref()/Unref() the captures frame, but there's really // We could Ref()/Unref() the captures frame, but there's really
// no need because by definition this current frame exists to // no need because by definition this current frame exists to
// enable execution of the function, and its captures frame won't // enable execution of the function, and its captures frame won't
// go away until the function itself goes away, which can only be // go away until the function itself goes away, which can only be
// after this frame does. // after this frame does.
captures = function ? function->GetCapturesFrame() : nullptr; captures = function ? function->GetCapturesFrame() : nullptr;
captures_offset_map = function ? function->GetCapturesOffsetMap() : nullptr; captures_offset_map = function ? function->GetCapturesOffsetMap() : nullptr;
current_offset = 0; current_offset = 0;
} }
void Frame::SetElement(int n, ValPtr v) void Frame::SetElement(int n, ValPtr v) {
{ n += current_offset;
n += current_offset; ASSERT(n >= 0 && n < size);
ASSERT(n >= 0 && n < size); frame[n] = std::move(v);
frame[n] = std::move(v); }
}
void Frame::SetElement(const ID* id, ValPtr v) void Frame::SetElement(const ID* id, ValPtr v) {
{ if ( captures ) {
if ( captures ) auto cap_off = captures_offset_map->find(id->Name());
{ if ( cap_off != captures_offset_map->end() ) {
auto cap_off = captures_offset_map->find(id->Name()); captures->SetElement(cap_off->second, std::move(v));
if ( cap_off != captures_offset_map->end() ) return;
{ }
captures->SetElement(cap_off->second, std::move(v)); }
return;
}
}
SetElement(id->Offset(), std::move(v)); SetElement(id->Offset(), std::move(v));
} }
const ValPtr& Frame::GetElementByID(const ID* id) const const ValPtr& Frame::GetElementByID(const ID* id) const {
{ if ( captures ) {
if ( captures ) auto cap_off = captures_offset_map->find(id->Name());
{ if ( cap_off != captures_offset_map->end() )
auto cap_off = captures_offset_map->find(id->Name()); return captures->GetElement(cap_off->second);
if ( cap_off != captures_offset_map->end() ) }
return captures->GetElement(cap_off->second);
}
return frame[id->Offset() + current_offset]; return frame[id->Offset() + current_offset];
} }
void Frame::Reset(int startIdx) void Frame::Reset(int startIdx) {
{ for ( int i = startIdx + current_offset; i < size; ++i )
for ( int i = startIdx + current_offset; i < size; ++i ) frame[i] = nullptr;
frame[i] = nullptr; }
}
void Frame::Describe(ODesc* d) const void Frame::Describe(ODesc* d) const {
{ if ( ! d->IsBinary() )
if ( ! d->IsBinary() ) d->AddSP("frame");
d->AddSP("frame");
if ( ! d->IsReadable() ) if ( ! d->IsReadable() ) {
{ d->Add(size);
d->Add(size);
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{ d->Add(frame[i] != nullptr);
d->Add(frame[i] != nullptr); d->SP();
d->SP(); }
} }
}
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
if ( frame[i] ) if ( frame[i] )
frame[i]->Describe(d); frame[i]->Describe(d);
else if ( d->IsReadable() ) else if ( d->IsReadable() )
d->Add("<nil>"); d->Add("<nil>");
} }
Frame* Frame::Clone() const Frame* Frame::Clone() const {
{ Frame* other = new Frame(size, function, func_args);
Frame* other = new Frame(size, function, func_args);
other->call = call; other->call = call;
other->assoc = assoc; other->assoc = assoc;
other->trigger = trigger; other->trigger = trigger;
for ( int i = 0; i < size; i++ ) for ( int i = 0; i < size; i++ )
if ( frame[i] ) if ( frame[i] )
other->frame[i] = frame[i]->Clone(); other->frame[i] = frame[i]->Clone();
// Note, there's no need to clone "captures" or "captures_offset_map" // Note, there's no need to clone "captures" or "captures_offset_map"
// since those get created fresh when constructing "other". // since those get created fresh when constructing "other".
return other; return other;
} }
Frame* Frame::CloneForTrigger() const Frame* Frame::CloneForTrigger() const {
{ Frame* other = new Frame(0, function, func_args);
Frame* other = new Frame(0, function, func_args);
other->call = call; other->call = call;
other->assoc = assoc; other->assoc = assoc;
other->trigger = trigger; other->trigger = trigger;
return other; return other;
} }
static bool val_is_func(const ValPtr& v, ScriptFunc* func) static bool val_is_func(const ValPtr& v, ScriptFunc* func) {
{ if ( v->GetType()->Tag() != TYPE_FUNC )
if ( v->GetType()->Tag() != TYPE_FUNC ) return false;
return false;
return v->AsFunc() == func; return v->AsFunc() == func;
} }
broker::expected<broker::data> Frame::Serialize() broker::expected<broker::data> Frame::Serialize() {
{ broker::vector body;
broker::vector body;
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{ const auto& val = frame[i];
const auto& val = frame[i]; auto expected = Broker::detail::val_to_data(val.get());
auto expected = Broker::detail::val_to_data(val.get()); if ( ! expected )
if ( ! expected ) return broker::ec::invalid_data;
return broker::ec::invalid_data;
TypeTag tag = val->GetType()->Tag(); TypeTag tag = val->GetType()->Tag();
broker::vector val_tuple{std::move(*expected), static_cast<broker::integer>(tag)}; broker::vector val_tuple{std::move(*expected), static_cast<broker::integer>(tag)};
body.emplace_back(std::move(val_tuple)); body.emplace_back(std::move(val_tuple));
} }
broker::vector rval; broker::vector rval;
rval.emplace_back(std::move(body)); rval.emplace_back(std::move(body));
return {std::move(rval)}; return {std::move(rval)};
} }
std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data) std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data) {
{ if ( data.size() == 0 )
if ( data.size() == 0 ) return std::make_pair(true, nullptr);
return std::make_pair(true, nullptr);
auto where = data.begin(); auto where = data.begin();
auto has_body = broker::get_if<broker::vector>(*where); auto has_body = broker::get_if<broker::vector>(*where);
if ( ! has_body ) if ( ! has_body )
return std::make_pair(false, nullptr); return std::make_pair(false, nullptr);
broker::vector body = *has_body; broker::vector body = *has_body;
int frame_size = body.size(); int frame_size = body.size();
auto rf = make_intrusive<Frame>(frame_size, nullptr, nullptr); auto rf = make_intrusive<Frame>(frame_size, nullptr, nullptr);
for ( int i = 0; i < frame_size; ++i ) for ( int i = 0; i < frame_size; ++i ) {
{ auto has_vec = broker::get_if<broker::vector>(body[i]);
auto has_vec = broker::get_if<broker::vector>(body[i]); if ( ! has_vec )
if ( ! has_vec ) continue;
continue;
broker::vector val_tuple = *has_vec; broker::vector val_tuple = *has_vec;
if ( val_tuple.size() != 2 ) if ( val_tuple.size() != 2 )
return std::make_pair(false, nullptr); return std::make_pair(false, nullptr);
auto has_type = broker::get_if<broker::integer>(val_tuple[1]); auto has_type = broker::get_if<broker::integer>(val_tuple[1]);
if ( ! has_type ) if ( ! has_type )
return std::make_pair(false, nullptr); return std::make_pair(false, nullptr);
broker::integer g = *has_type; broker::integer g = *has_type;
Type t(static_cast<TypeTag>(g)); Type t(static_cast<TypeTag>(g));
auto val = Broker::detail::data_to_val(std::move(val_tuple[0]), &t); auto val = Broker::detail::data_to_val(std::move(val_tuple[0]), &t);
if ( ! val ) if ( ! val )
return std::make_pair(false, nullptr); return std::make_pair(false, nullptr);
rf->frame[i] = std::move(val); rf->frame[i] = std::move(val);
} }
return std::make_pair(true, std::move(rf)); return std::make_pair(true, std::move(rf));
} }
const detail::Location* Frame::GetCallLocation() const const detail::Location* Frame::GetCallLocation() const {
{ // This is currently trivial, but we keep it as an explicit
// This is currently trivial, but we keep it as an explicit // method because it can provide flexibility for compiled code.
// method because it can provide flexibility for compiled code. return call->GetLocationInfo();
return call->GetLocationInfo(); }
}
void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) { trigger = std::move(arg_trigger); }
{
trigger = std::move(arg_trigger);
}
void Frame::ClearTrigger() void Frame::ClearTrigger() { trigger = nullptr; }
{
trigger = nullptr;
}
} } // namespace zeek::detail

View file

@ -17,250 +17,244 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" // for typedef val_list #include "zeek/ZeekList.h" // for typedef val_list
namespace zeek namespace zeek {
{
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
namespace detail namespace detail {
{
class CallExpr; class CallExpr;
class ScriptFunc; class ScriptFunc;
using IDPtr = IntrusivePtr<ID>; using IDPtr = IntrusivePtr<ID>;
namespace trigger namespace trigger {
{
class Trigger; class Trigger;
using TriggerPtr = IntrusivePtr<Trigger>; using TriggerPtr = IntrusivePtr<Trigger>;
} } // namespace trigger
class Frame; class Frame;
using FramePtr = IntrusivePtr<Frame>; using FramePtr = IntrusivePtr<Frame>;
class Frame : public Obj class Frame : public Obj {
{
public: public:
/** /**
* Constructs a new frame belonging to *func* with *fn_args* * Constructs a new frame belonging to *func* with *fn_args*
* arguments. * arguments.
* *
* @param the size of the frame * @param the size of the frame
* @param func the function that is creating this frame * @param func the function that is creating this frame
* @param fn_args the arguments being passed to that function. * @param fn_args the arguments being passed to that function.
*/ */
Frame(int size, const ScriptFunc* func, const zeek::Args* fn_args); Frame(int size, const ScriptFunc* func, const zeek::Args* fn_args);
/** /**
* Returns the size of the frame. * Returns the size of the frame.
* *
* @return the number of elements in the frame. * @return the number of elements in the frame.
*/ */
int FrameSize() const { return size; } int FrameSize() const { return size; }
/** /**
* @param n the index to get. * @param n the index to get.
* @return the value at index *n* of the underlying array. * @return the value at index *n* of the underlying array.
*/ */
const ValPtr& GetElement(int n) const const ValPtr& GetElement(int n) const {
{ // Note: technically this may want to adjust by current_offset, but
// Note: technically this may want to adjust by current_offset, but // in practice, this method is never called from anywhere other than
// in practice, this method is never called from anywhere other than // function call invocation, where current_offset should be zero.
// function call invocation, where current_offset should be zero. return frame[n];
return frame[n]; }
}
/** /**
* Sets the element at index *n* of the underlying array to *v*. * Sets the element at index *n* of the underlying array to *v*.
* @param n the index to set * @param n the index to set
* @param v the value to set it to * @param v the value to set it to
*/ */
void SetElement(int n, ValPtr v); void SetElement(int n, ValPtr v);
/** /**
* Associates *id* and *v* in the frame. Future lookups of * Associates *id* and *v* in the frame. Future lookups of
* *id* will return *v*. * *id* will return *v*.
* *
* @param id the ID to associate * @param id the ID to associate
* @param v the value to associate it with * @param v the value to associate it with
*/ */
void SetElement(const ID* id, ValPtr v); void SetElement(const ID* id, ValPtr v);
void SetElement(const IDPtr& id, ValPtr v) { SetElement(id.get(), std::move(v)); } void SetElement(const IDPtr& id, ValPtr v) { SetElement(id.get(), std::move(v)); }
/** /**
* Gets the value associated with *id* and returns it. Returns * Gets the value associated with *id* and returns it. Returns
* nullptr if no such element exists. * nullptr if no such element exists.
* *
* @param id the id who's value to retrieve * @param id the id who's value to retrieve
* @return the value associated with *id* * @return the value associated with *id*
*/ */
const ValPtr& GetElementByID(const IDPtr& id) const { return GetElementByID(id.get()); } const ValPtr& GetElementByID(const IDPtr& id) const { return GetElementByID(id.get()); }
/** /**
* Adjusts the current offset being used for frame accesses. * Adjusts the current offset being used for frame accesses.
* This is in support of inlined functions. * This is in support of inlined functions.
* *
* @param incr Amount by which to increase the frame offset. * @param incr Amount by which to increase the frame offset.
* Use a negative value to shrink the offset. * Use a negative value to shrink the offset.
*/ */
void AdjustOffset(int incr) { current_offset += incr; } void AdjustOffset(int incr) { current_offset += incr; }
/** /**
* Resets all of the indexes from [*startIdx, frame_size) in * Resets all of the indexes from [*startIdx, frame_size) in
* the Frame. * the Frame.
* @param the first index to unref. * @param the first index to unref.
*/ */
void Reset(int startIdx); void Reset(int startIdx);
/** /**
* Describes the frame and all of its values. * Describes the frame and all of its values.
*/ */
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
/** /**
* @return the function that the frame is associated with. * @return the function that the frame is associated with.
*/ */
const ScriptFunc* GetFunction() const { return function; } const ScriptFunc* GetFunction() const { return function; }
/** /**
* @return the arguments passed to the function that this frame * @return the arguments passed to the function that this frame
* is associated with. * is associated with.
*/ */
const Args* GetFuncArgs() const { return func_args; } const Args* GetFuncArgs() const { return func_args; }
/** /**
* Change the function that the frame is associated with. * Change the function that the frame is associated with.
* *
* @param func the function for the frame to be associated with. * @param func the function for the frame to be associated with.
*/ */
void SetFunction(ScriptFunc* func) { function = func; } void SetFunction(ScriptFunc* func) { function = func; }
/** /**
* Sets the next statement to be executed in the context of the frame. * Sets the next statement to be executed in the context of the frame.
* *
* @param stmt the statement to set it to. * @param stmt the statement to set it to.
*/ */
void SetNextStmt(Stmt* stmt) { next_stmt = stmt; } void SetNextStmt(Stmt* stmt) { next_stmt = stmt; }
/** /**
* @return the next statement to be executed in the context of the frame. * @return the next statement to be executed in the context of the frame.
*/ */
Stmt* GetNextStmt() const { return next_stmt; } Stmt* GetNextStmt() const { return next_stmt; }
/** Used to implement "next" command in debugger. */ /** Used to implement "next" command in debugger. */
void BreakBeforeNextStmt(bool should_break) { break_before_next_stmt = should_break; } void BreakBeforeNextStmt(bool should_break) { break_before_next_stmt = should_break; }
bool BreakBeforeNextStmt() const { return break_before_next_stmt; } bool BreakBeforeNextStmt() const { return break_before_next_stmt; }
/** Used to implement "finish" command in debugger. */ /** Used to implement "finish" command in debugger. */
void BreakOnReturn(bool should_break) { break_on_return = should_break; } void BreakOnReturn(bool should_break) { break_on_return = should_break; }
bool BreakOnReturn() const { return break_on_return; } bool BreakOnReturn() const { return break_on_return; }
/** /**
* Performs a deep copy of all the values in the current frame. * Performs a deep copy of all the values in the current frame.
* *
* @return a copy of this frame. * @return a copy of this frame.
*/ */
Frame* Clone() const; Frame* Clone() const;
/** /**
* Creates a copy of the frame that just includes its trigger context. * Creates a copy of the frame that just includes its trigger context.
* *
* @return a partial copy of this frame. * @return a partial copy of this frame.
*/ */
Frame* CloneForTrigger() const; Frame* CloneForTrigger() const;
/** /**
* Serializes the frame (only done for lambda/when captures) as a * Serializes the frame (only done for lambda/when captures) as a
* sequence of two-element vectors, the first element reflecting * sequence of two-element vectors, the first element reflecting
* the frame value, the second its type. * the frame value, the second its type.
*/ */
broker::expected<broker::data> Serialize(); broker::expected<broker::data> Serialize();
/** /**
* Instantiates a Frame from a serialized one. * Instantiates a Frame from a serialized one.
* *
* @return a pair in which the first item is the status of the serialization; * @return a pair in which the first item is the status of the serialization;
* and the second is the unserialized frame with reference count +1, or * and the second is the unserialized frame with reference count +1, or
* null if the serialization wasn't successful. * null if the serialization wasn't successful.
*/ */
static std::pair<bool, FramePtr> Unserialize(const broker::vector& data); static std::pair<bool, FramePtr> Unserialize(const broker::vector& data);
// If the frame is run in the context of a trigger condition evaluation, // If the frame is run in the context of a trigger condition evaluation,
// the trigger needs to be registered. // the trigger needs to be registered.
void SetTrigger(trigger::TriggerPtr arg_trigger); void SetTrigger(trigger::TriggerPtr arg_trigger);
void ClearTrigger(); void ClearTrigger();
trigger::Trigger* GetTrigger() const { return trigger.get(); } trigger::Trigger* GetTrigger() const { return trigger.get(); }
void SetCall(const CallExpr* arg_call) void SetCall(const CallExpr* arg_call) {
{ call = arg_call;
call = arg_call; SetTriggerAssoc((void*)call);
SetTriggerAssoc((void*)call); }
} void SetOnlyCall(const CallExpr* arg_call) { call = arg_call; }
void SetOnlyCall(const CallExpr* arg_call) { call = arg_call; } const CallExpr* GetCall() const { return call; }
const CallExpr* GetCall() const { return call; }
void SetTriggerAssoc(const void* arg_assoc) { assoc = arg_assoc; } void SetTriggerAssoc(const void* arg_assoc) { assoc = arg_assoc; }
const void* GetTriggerAssoc() const { return assoc; } const void* GetTriggerAssoc() const { return assoc; }
const detail::Location* GetCallLocation() const; const detail::Location* GetCallLocation() const;
void SetDelayed() { delayed = true; } void SetDelayed() { delayed = true; }
bool HasDelayed() const { return delayed; } bool HasDelayed() const { return delayed; }
private: private:
using OffsetMap = std::unordered_map<std::string, int>; using OffsetMap = std::unordered_map<std::string, int>;
// This has a trivial form now, but used to hold additional // This has a trivial form now, but used to hold additional
// information, which is why we abstract it away from just being // information, which is why we abstract it away from just being
// a ValPtr. // a ValPtr.
using Element = ValPtr; using Element = ValPtr;
const ValPtr& GetElementByID(const ID* id) const; const ValPtr& GetElementByID(const ID* id) const;
/** The number of vals that can be stored in this frame. */ /** The number of vals that can be stored in this frame. */
int size; int size;
bool break_before_next_stmt; bool break_before_next_stmt;
bool break_on_return; bool break_on_return;
bool delayed; bool delayed;
/** Associates ID's offsets with values. */ /** Associates ID's offsets with values. */
std::unique_ptr<Element[]> frame; std::unique_ptr<Element[]> frame;
/** /**
* The offset we're currently using for references into the frame. * The offset we're currently using for references into the frame.
* This is how we support inlined functions without having to * This is how we support inlined functions without having to
* alter the offsets associated with their local variables. * alter the offsets associated with their local variables.
*/ */
int current_offset; int current_offset;
/** Frame used for lambda/when captures. */ /** Frame used for lambda/when captures. */
Frame* captures; Frame* captures;
/** Maps IDs to offsets into the "captures" frame. If the ID /** Maps IDs to offsets into the "captures" frame. If the ID
* isn't present, then it's not a capture. * isn't present, then it's not a capture.
*/ */
const OffsetMap* captures_offset_map; const OffsetMap* captures_offset_map;
/** The function this frame is associated with. */ /** The function this frame is associated with. */
const ScriptFunc* function; const ScriptFunc* function;
// The following is only needed for the debugger. // The following is only needed for the debugger.
/** The arguments to the function that this Frame is associated with. */ /** The arguments to the function that this Frame is associated with. */
const zeek::Args* func_args; const zeek::Args* func_args;
/** The next statement to be evaluated in the context of this frame. */ /** The next statement to be evaluated in the context of this frame. */
Stmt* next_stmt; Stmt* next_stmt;
trigger::TriggerPtr trigger; trigger::TriggerPtr trigger;
const CallExpr* call = nullptr; const CallExpr* call = nullptr;
const void* assoc = nullptr; const void* assoc = nullptr;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek
/** /**
* If we stopped using this and instead just made a struct of the information * If we stopped using this and instead just made a struct of the information

File diff suppressed because it is too large Load diff

View file

@ -18,21 +18,19 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
namespace broker namespace broker {
{
class data; class data;
using vector = std::vector<data>; using vector = std::vector<data>;
template <class> class expected; template<class>
} class expected;
} // namespace broker
namespace zeek namespace zeek {
{
class Val; class Val;
class FuncType; class FuncType;
namespace detail namespace detail {
{
class Scope; class Scope;
class Stmt; class Stmt;
@ -46,7 +44,7 @@ using StmtPtr = IntrusivePtr<Stmt>;
class ScriptFunc; class ScriptFunc;
class FunctionIngredients; class FunctionIngredients;
} // namespace detail } // namespace detail
class EventGroup; class EventGroup;
using EventGroupPtr = std::shared_ptr<EventGroup>; using EventGroupPtr = std::shared_ptr<EventGroup>;
@ -54,350 +52,329 @@ using EventGroupPtr = std::shared_ptr<EventGroup>;
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
class Func : public Obj class Func : public Obj {
{
public: public:
static inline const FuncPtr nil; static inline const FuncPtr nil;
enum Kind enum Kind { SCRIPT_FUNC, BUILTIN_FUNC };
{
SCRIPT_FUNC,
BUILTIN_FUNC
};
explicit Func(Kind arg_kind) : kind(arg_kind) { } explicit Func(Kind arg_kind) : kind(arg_kind) {}
virtual bool IsPure() const = 0; virtual bool IsPure() const = 0;
FunctionFlavor Flavor() const { return GetType()->Flavor(); } FunctionFlavor Flavor() const { return GetType()->Flavor(); }
struct Body struct Body {
{ detail::StmtPtr stmts;
detail::StmtPtr stmts; int priority;
int priority; std::set<EventGroupPtr> groups;
std::set<EventGroupPtr> groups; // If any of the groups are disabled, this body is disabled.
// If any of the groups are disabled, this body is disabled. // The disabled field is updated from EventGroup instances.
// The disabled field is updated from EventGroup instances. bool disabled = false;
bool disabled = false;
bool operator<(const Body& other) const bool operator<(const Body& other) const { return priority > other.priority; } // reverse sort
{ };
return priority > other.priority;
} // reverse sort
};
const std::vector<Body>& GetBodies() const { return bodies; } const std::vector<Body>& GetBodies() const { return bodies; }
bool HasBodies() const { return ! bodies.empty(); } bool HasBodies() const { return ! bodies.empty(); }
/** /**
* Are there bodies and is any one of them enabled? * Are there bodies and is any one of them enabled?
* *
* @return true if bodies exist and at least one is enabled. * @return true if bodies exist and at least one is enabled.
*/ */
bool HasEnabledBodies() const { return ! bodies.empty() && has_enabled_bodies; }; bool HasEnabledBodies() const { return ! bodies.empty() && has_enabled_bodies; };
/** /**
* Calls a Zeek function. * Calls a Zeek function.
* @param args the list of arguments to the function call. * @param args the list of arguments to the function call.
* @param parent the frame from which the function is being called. * @param parent the frame from which the function is being called.
* @return the return value of the function call. * @return the return value of the function call.
*/ */
virtual ValPtr Invoke(zeek::Args* args, detail::Frame* parent = nullptr) const = 0; virtual ValPtr Invoke(zeek::Args* args, detail::Frame* parent = nullptr) const = 0;
/** /**
* A version of Invoke() taking a variable number of individual arguments. * A version of Invoke() taking a variable number of individual arguments.
*/ */
template <class... Args> template<class... Args>
std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>, std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>, ValPtr> Invoke(
ValPtr> Args&&... args) const {
Invoke(Args&&... args) const auto zargs = zeek::Args{std::forward<Args>(args)...};
{ return Invoke(&zargs);
auto zargs = zeek::Args{std::forward<Args>(args)...}; }
return Invoke(&zargs);
}
// Various ways to add a new event handler to an existing function // Various ways to add a new event handler to an existing function
// (event). The usual version to use is the first with its default // (event). The usual version to use is the first with its default
// parameter. All of the others are for use by script optimization, // parameter. All of the others are for use by script optimization,
// as is a non-default second parameter to the first method, which // as is a non-default second parameter to the first method, which
// overrides the function body in "ingr". // overrides the function body in "ingr".
void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr); void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr);
virtual void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, virtual void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
size_t new_frame_size, int priority, int priority, const std::set<EventGroupPtr>& groups);
const std::set<EventGroupPtr>& groups); void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, int priority = 0);
size_t new_frame_size, int priority = 0); void AddBody(detail::StmtPtr new_body, size_t new_frame_size);
void AddBody(detail::StmtPtr new_body, size_t new_frame_size);
virtual void SetScope(detail::ScopePtr newscope); virtual void SetScope(detail::ScopePtr newscope);
virtual detail::ScopePtr GetScope() const { return scope; } virtual detail::ScopePtr GetScope() const { return scope; }
const FuncTypePtr& GetType() const { return type; } const FuncTypePtr& GetType() const { return type; }
Kind GetKind() const { return kind; } Kind GetKind() const { return kind; }
const char* Name() const { return name.c_str(); } const char* Name() const { return name.c_str(); }
void SetName(const char* arg_name) { name = arg_name; } void SetName(const char* arg_name) { name = arg_name; }
void Describe(ODesc* d) const override = 0; void Describe(ODesc* d) const override = 0;
virtual void DescribeDebug(ODesc* d, const zeek::Args* args) const; virtual void DescribeDebug(ODesc* d, const zeek::Args* args) const;
virtual FuncPtr DoClone(); virtual FuncPtr DoClone();
virtual detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; virtual detail::TraversalCode Traverse(detail::TraversalCallback* cb) const;
protected: protected:
Func() = default; Func() = default;
// Copies this function's state into other. // Copies this function's state into other.
void CopyStateInto(Func* other) const; void CopyStateInto(Func* other) const;
// Helper function for checking result of plugin hook. // Helper function for checking result of plugin hook.
void CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFlavor flavor) const; void CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFlavor flavor) const;
std::vector<Body> bodies; std::vector<Body> bodies;
detail::ScopePtr scope; detail::ScopePtr scope;
Kind kind = SCRIPT_FUNC; Kind kind = SCRIPT_FUNC;
FuncTypePtr type; FuncTypePtr type;
std::string name; std::string name;
private: private:
// EventGroup updates Func::Body.disabled and has_enabled_bodies. // EventGroup updates Func::Body.disabled and has_enabled_bodies.
// This is friend/private with EventGroup here so that we do not // This is friend/private with EventGroup here so that we do not
// expose accessors in the zeek:: public interface. // expose accessors in the zeek:: public interface.
friend class EventGroup; friend class EventGroup;
bool has_enabled_bodies = true; bool has_enabled_bodies = true;
}; };
namespace detail namespace detail {
{
class ScriptFunc : public Func class ScriptFunc : public Func {
{
public: public:
ScriptFunc(const IDPtr& id); ScriptFunc(const IDPtr& id);
// For compiled scripts. // For compiled scripts.
ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies, ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies, std::vector<int> priorities);
std::vector<int> priorities);
~ScriptFunc() override; ~ScriptFunc() override;
bool IsPure() const override; bool IsPure() const override;
ValPtr Invoke(zeek::Args* args, Frame* parent) const override; ValPtr Invoke(zeek::Args* args, Frame* parent) const override;
/** /**
* Creates a separate frame for captures and initializes its * Creates a separate frame for captures and initializes its
* elements. The list of captures comes from the ScriptFunc's * elements. The list of captures comes from the ScriptFunc's
* type, so doesn't need to be passed in, just the frame to * type, so doesn't need to be passed in, just the frame to
* use in evaluating the identifiers. * use in evaluating the identifiers.
* *
* @param f the frame used for evaluating the captured identifiers * @param f the frame used for evaluating the captured identifiers
*/ */
void CreateCaptures(Frame* f); void CreateCaptures(Frame* f);
/** /**
* Uses the given set of ZVal's for captures. Note that this is * Uses the given set of ZVal's for captures. Note that this is
* different from the method above, which uses its argument to * different from the method above, which uses its argument to
* compute the captures, rather than here where they are pre-computed. * compute the captures, rather than here where they are pre-computed.
* *
* Makes deep copies if required. * Makes deep copies if required.
* *
* @param cvec a vector of ZVal's corresponding to the captures. * @param cvec a vector of ZVal's corresponding to the captures.
*/ */
void CreateCaptures(std::unique_ptr<std::vector<ZVal>> cvec); void CreateCaptures(std::unique_ptr<std::vector<ZVal>> cvec);
/** /**
* Returns the frame associated with this function for tracking * Returns the frame associated with this function for tracking
* captures, or nil if there isn't one. * captures, or nil if there isn't one.
* *
* @return internal frame kept by the function for persisting captures * @return internal frame kept by the function for persisting captures
*/ */
Frame* GetCapturesFrame() const { return captures_frame; } Frame* GetCapturesFrame() const { return captures_frame; }
/** /**
* Returns the set of ZVal's used for captures. It's okay to modify * Returns the set of ZVal's used for captures. It's okay to modify
* these as long as memory-management is done for managed entries. * these as long as memory-management is done for managed entries.
* *
* @return internal vector of ZVal's kept for persisting captures * @return internal vector of ZVal's kept for persisting captures
*/ */
auto& GetCapturesVec() const auto& GetCapturesVec() const {
{ ASSERT(captures_vec);
ASSERT(captures_vec); return *captures_vec;
return *captures_vec; }
}
// Same definition as in Frame.h. // Same definition as in Frame.h.
using OffsetMap = std::unordered_map<std::string, int>; using OffsetMap = std::unordered_map<std::string, int>;
/** /**
* Returns the mapping of captures to slots in the captures frame. * Returns the mapping of captures to slots in the captures frame.
* *
* @return pointer to mapping of captures to slots * @return pointer to mapping of captures to slots
*/ */
const OffsetMap* GetCapturesOffsetMap() const { return captures_offset_mapping; } const OffsetMap* GetCapturesOffsetMap() const { return captures_offset_mapping; }
/** /**
* Serializes this function's capture frame. * Serializes this function's capture frame.
* *
* @return a serialized version of the function's capture frame. * @return a serialized version of the function's capture frame.
*/ */
virtual broker::expected<broker::data> SerializeCaptures() const; virtual broker::expected<broker::data> SerializeCaptures() const;
/** /**
* Sets the captures frame to one built from *data*. * Sets the captures frame to one built from *data*.
* *
* @param data a serialized frame * @param data a serialized frame
*/ */
bool DeserializeCaptures(const broker::vector& data); bool DeserializeCaptures(const broker::vector& data);
using Func::AddBody; using Func::AddBody;
void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
size_t new_frame_size, int priority, int priority, const std::set<EventGroupPtr>& groups) override;
const std::set<EventGroupPtr>& groups) override;
/** /**
* Replaces the given current instance of a function body with * Replaces the given current instance of a function body with
* a new one. If new_body is nil then the current instance is * a new one. If new_body is nil then the current instance is
* deleted with no replacement. * deleted with no replacement.
* *
* @param old_body Body to replace. * @param old_body Body to replace.
* @param new_body New body to use; can be nil. * @param new_body New body to use; can be nil.
*/ */
void ReplaceBody(const detail::StmtPtr& old_body, detail::StmtPtr new_body); void ReplaceBody(const detail::StmtPtr& old_body, detail::StmtPtr new_body);
StmtPtr CurrentBody() const { return current_body; } StmtPtr CurrentBody() const { return current_body; }
int CurrentPriority() const { return current_priority; } int CurrentPriority() const { return current_priority; }
/** /**
* Returns the function's frame size. * Returns the function's frame size.
* @return The number of ValPtr slots in the function's frame. * @return The number of ValPtr slots in the function's frame.
*/ */
int FrameSize() const { return frame_size; } int FrameSize() const { return frame_size; }
/** /**
* Changes the function's frame size to a new size - used for * Changes the function's frame size to a new size - used for
* script optimization/compilation. * script optimization/compilation.
* *
* @param new_size The frame size the function should use. * @param new_size The frame size the function should use.
*/ */
void SetFrameSize(int new_size) { frame_size = new_size; } void SetFrameSize(int new_size) { frame_size = new_size; }
/** Sets this function's outer_id list. */ /** Sets this function's outer_id list. */
void SetOuterIDs(IDPList ids) { outer_ids = std::move(ids); } void SetOuterIDs(IDPList ids) { outer_ids = std::move(ids); }
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
protected: protected:
ScriptFunc() : Func(SCRIPT_FUNC) { } ScriptFunc() : Func(SCRIPT_FUNC) {}
StmtPtr AddInits(StmtPtr body, const std::vector<IDPtr>& inits); StmtPtr AddInits(StmtPtr body, const std::vector<IDPtr>& inits);
/** /**
* Clones this function along with its captures. * Clones this function along with its captures.
*/ */
FuncPtr DoClone() override; FuncPtr DoClone() override;
/** /**
* Uses the given frame for captures, and generates the * Uses the given frame for captures, and generates the
* mapping from captured variables to offsets in the frame. * mapping from captured variables to offsets in the frame.
* Virtual so it can be modified for script optimization uses. * Virtual so it can be modified for script optimization uses.
* *
* @param f the frame holding the values of capture variables * @param f the frame holding the values of capture variables
*/ */
virtual void SetCaptures(Frame* f); virtual void SetCaptures(Frame* f);
private: private:
size_t frame_size = 0; size_t frame_size = 0;
// List of the outer IDs used in the function. // List of the outer IDs used in the function.
IDPList outer_ids; IDPList outer_ids;
// Frame for (capture-by-copy) closures. These persist over the // Frame for (capture-by-copy) closures. These persist over the
// function's lifetime, providing quasi-globals that maintain // function's lifetime, providing quasi-globals that maintain
// state across individual calls to the function. // state across individual calls to the function.
Frame* captures_frame = nullptr; Frame* captures_frame = nullptr;
OffsetMap* captures_offset_mapping = nullptr; OffsetMap* captures_offset_mapping = nullptr;
// Captures when using ZVal block instead of a Frame. // Captures when using ZVal block instead of a Frame.
std::unique_ptr<std::vector<ZVal>> captures_vec; std::unique_ptr<std::vector<ZVal>> captures_vec;
// The most recently added/updated body ... // The most recently added/updated body ...
StmtPtr current_body; StmtPtr current_body;
// ... and its priority. // ... and its priority.
int current_priority = 0; int current_priority = 0;
}; };
using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args); using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args);
class BuiltinFunc final : public Func class BuiltinFunc final : public Func {
{
public: public:
BuiltinFunc(built_in_func func, const char* name, bool is_pure); BuiltinFunc(built_in_func func, const char* name, bool is_pure);
~BuiltinFunc() override = default; ~BuiltinFunc() override = default;
bool IsPure() const override; bool IsPure() const override;
ValPtr Invoke(zeek::Args* args, Frame* parent) const override; ValPtr Invoke(zeek::Args* args, Frame* parent) const override;
built_in_func TheFunc() const { return func; } built_in_func TheFunc() const { return func; }
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
protected: protected:
BuiltinFunc() BuiltinFunc() {
{ func = nullptr;
func = nullptr; is_pure = 0;
is_pure = 0; }
}
built_in_func func; built_in_func func;
bool is_pure; bool is_pure;
}; };
extern bool check_built_in_call(BuiltinFunc* f, CallExpr* call); extern bool check_built_in_call(BuiltinFunc* f, CallExpr* call);
struct CallInfo struct CallInfo {
{ const CallExpr* call;
const CallExpr* call; const Func* func;
const Func* func; const zeek::Args& args;
const zeek::Args& args; };
};
// Class that collects all the specifics defining a Func. // Class that collects all the specifics defining a Func.
class FunctionIngredients class FunctionIngredients {
{
public: public:
// Gathers all of the information from a scope and a function body needed // Gathers all of the information from a scope and a function body needed
// to build a function. // to build a function.
FunctionIngredients(ScopePtr scope, StmtPtr body, const std::string& module_name); FunctionIngredients(ScopePtr scope, StmtPtr body, const std::string& module_name);
const IDPtr& GetID() const { return id; } const IDPtr& GetID() const { return id; }
const StmtPtr& Body() const { return body; } const StmtPtr& Body() const { return body; }
void ReplaceBody(StmtPtr new_body) { body = std::move(new_body); } void ReplaceBody(StmtPtr new_body) { body = std::move(new_body); }
const auto& Inits() const { return inits; } const auto& Inits() const { return inits; }
void ClearInits() { inits.clear(); } void ClearInits() { inits.clear(); }
size_t FrameSize() const { return frame_size; } size_t FrameSize() const { return frame_size; }
int Priority() const { return priority; } int Priority() const { return priority; }
const ScopePtr& Scope() const { return scope; } const ScopePtr& Scope() const { return scope; }
const auto& Groups() const { return groups; } const auto& Groups() const { return groups; }
// Used by script optimization to update lambda ingredients // Used by script optimization to update lambda ingredients
// after compilation. // after compilation.
void SetFrameSize(size_t _frame_size) { frame_size = _frame_size; } void SetFrameSize(size_t _frame_size) { frame_size = _frame_size; }
private: private:
IDPtr id; IDPtr id;
StmtPtr body; StmtPtr body;
std::vector<IDPtr> inits; std::vector<IDPtr> inits;
size_t frame_size = 0; size_t frame_size = 0;
int priority = 0; int priority = 0;
ScopePtr scope; ScopePtr scope;
std::set<EventGroupPtr> groups; std::set<EventGroupPtr> groups;
}; };
using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>; using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>;
@ -427,19 +404,18 @@ extern bool did_builtin_init;
extern std::vector<void (*)()> bif_initializers; extern std::vector<void (*)()> bif_initializers;
extern void init_primary_bifs(); extern void init_primary_bifs();
inline void run_bif_initializers() inline void run_bif_initializers() {
{ for ( const auto& bi : bif_initializers )
for ( const auto& bi : bif_initializers ) bi();
bi();
bif_initializers = {}; bif_initializers = {};
} }
extern void emit_builtin_exception(const char* msg); extern void emit_builtin_exception(const char* msg);
extern void emit_builtin_exception(const char* msg, const ValPtr& arg); extern void emit_builtin_exception(const char* msg, const ValPtr& arg);
extern void emit_builtin_exception(const char* msg, Obj* arg); extern void emit_builtin_exception(const char* msg, Obj* arg);
} // namespace detail } // namespace detail
extern std::string render_call_stack(); extern std::string render_call_stack();
@ -448,4 +424,4 @@ extern void emit_builtin_error(const char* msg);
extern void emit_builtin_error(const char* msg, const ValPtr&); extern void emit_builtin_error(const char* msg, const ValPtr&);
extern void emit_builtin_error(const char* msg, Obj* arg); extern void emit_builtin_error(const char* msg, Obj* arg);
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -27,368 +27,356 @@
// to allow md5_hmac_bif access to the hmac seed // to allow md5_hmac_bif access to the hmac seed
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
namespace zeek namespace zeek {
{
class String; class String;
class ODesc; class ODesc;
} } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class Frame; class Frame;
class BifReturnVal; class BifReturnVal;
} } // namespace zeek::detail
namespace zeek::BifFunc namespace zeek::BifFunc {
{
extern zeek::detail::BifReturnVal md5_hmac_bif(zeek::detail::Frame* frame, const zeek::Args*); extern zeek::detail::BifReturnVal md5_hmac_bif(zeek::detail::Frame* frame, const zeek::Args*);
} }
namespace zeek::detail namespace zeek::detail {
{
using hash_t = uint64_t; using hash_t = uint64_t;
using hash64_t = uint64_t; using hash64_t = uint64_t;
using hash128_t = uint64_t[2]; using hash128_t = uint64_t[2];
using hash256_t = uint64_t[4]; using hash256_t = uint64_t[4];
class KeyedHash class KeyedHash {
{
public: public:
/** /**
* Generate a 64 bit digest hash. * Generate a 64 bit digest hash.
* *
* This hash is seeded with random data, unless the ZEEK_SEED_FILE environment * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment
* variable is set. Thus, typically every node will return a different hash * variable is set. Thus, typically every node will return a different hash
* after every restart. * after every restart.
* *
* This should be used for internal hashes that do not have to be stable over * This should be used for internal hashes that do not have to be stable over
* the cluster/runs - like, e.g. connection ID generation. * the cluster/runs - like, e.g. connection ID generation.
* *
* @param bytes Bytes to hash * @param bytes Bytes to hash
* *
* @param size Size of bytes * @param size Size of bytes
* *
* @returns 64 bit digest hash * @returns 64 bit digest hash
*/ */
static hash64_t Hash64(const void* bytes, uint64_t size); static hash64_t Hash64(const void* bytes, uint64_t size);
/** /**
* Generate a 128 bit digest hash. * Generate a 128 bit digest hash.
* *
* This hash is seeded with random data, unless the ZEEK_SEED_FILE environment * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment
* variable is set. Thus, typically every node will return a different hash * variable is set. Thus, typically every node will return a different hash
* after every restart. * after every restart.
* *
* This should be used for internal hashes that do not have to be stable over * This should be used for internal hashes that do not have to be stable over
* the cluster/runs - like, e.g. connection ID generation. * the cluster/runs - like, e.g. connection ID generation.
* *
* @param bytes Bytes to hash * @param bytes Bytes to hash
* *
* @param size Size of bytes * @param size Size of bytes
* *
* @param result Result of the hashing operation. * @param result Result of the hashing operation.
*/ */
static void Hash128(const void* bytes, uint64_t size, hash128_t* result); static void Hash128(const void* bytes, uint64_t size, hash128_t* result);
/** /**
* Generate a 256 bit digest hash. * Generate a 256 bit digest hash.
* *
* This hash is seeded with random data, unless the ZEEK_SEED_FILE environment * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment
* variable is set. Thus, typically every node will return a different hash * variable is set. Thus, typically every node will return a different hash
* after every restart. * after every restart.
* *
* This should be used for internal hashes that do not have to be stable over * This should be used for internal hashes that do not have to be stable over
* the cluster/runs - like, e.g. connection ID generation. * the cluster/runs - like, e.g. connection ID generation.
* *
* @param bytes Bytes to hash * @param bytes Bytes to hash
* *
* @param size Size of bytes * @param size Size of bytes
* *
* @param result Result of the hashing operation. * @param result Result of the hashing operation.
*/ */
static void Hash256(const void* bytes, uint64_t size, hash256_t* result); static void Hash256(const void* bytes, uint64_t size, hash256_t* result);
/** /**
* Generates a installation-specific 64 bit hash. * Generates a installation-specific 64 bit hash.
* *
* This function generates a 64 bit digest hash, which is stable over a cluster * This function generates a 64 bit digest hash, which is stable over a cluster
* or a restart. * or a restart.
* *
* To be more exact - the seed value for this hash is generated from the script-level * To be more exact - the seed value for this hash is generated from the script-level
* :zeek:see:`digest_salt` constant. The seeds are stable as long as this value * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value
* is not changed. * is not changed.
* *
* This should be used for hashes that have to remain stable over the entire * This should be used for hashes that have to remain stable over the entire
* cluster. An example are file IDs, which have to be stable over several workers. * cluster. An example are file IDs, which have to be stable over several workers.
* *
* @param bytes Bytes to hash * @param bytes Bytes to hash
* *
* @param size Size of bytes * @param size Size of bytes
* *
* @returns 64 bit digest hash * @returns 64 bit digest hash
*/ */
static hash64_t StaticHash64(const void* bytes, uint64_t size); static hash64_t StaticHash64(const void* bytes, uint64_t size);
/** /**
* Generates a installation-specific 128 bit hash. * Generates a installation-specific 128 bit hash.
* *
* This function generates a 128 bit digest hash, which is stable over a cluster * This function generates a 128 bit digest hash, which is stable over a cluster
* or a restart. * or a restart.
* *
* To be more exact - the seed value for this hash is generated from the script-level * To be more exact - the seed value for this hash is generated from the script-level
* :zeek:see:`digest_salt` constant. The seeds are stable as long as this value * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value
* is not changed. * is not changed.
* *
* This should be used for hashes that have to remain stable over the entire * This should be used for hashes that have to remain stable over the entire
* cluster. An example are file IDs, which have to be stable over several workers. * cluster. An example are file IDs, which have to be stable over several workers.
* *
* @param bytes Bytes to hash * @param bytes Bytes to hash
* *
* @param size Size of bytes * @param size Size of bytes
* *
* @param result Result of the hashing operation. * @param result Result of the hashing operation.
*/ */
static void StaticHash128(const void* bytes, uint64_t size, hash128_t* result); static void StaticHash128(const void* bytes, uint64_t size, hash128_t* result);
/** /**
* Generates a installation-specific 256 bit hash. * Generates a installation-specific 256 bit hash.
* *
* This function generates a 128 bit digest hash, which is stable over a cluster * This function generates a 128 bit digest hash, which is stable over a cluster
* or a restart. * or a restart.
* *
* To be more exact - the seed value for this hash is generated from the script-level * To be more exact - the seed value for this hash is generated from the script-level
* :zeek:see:`digest_salt` constant. The seeds are stable as long as this value * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value
* is not changed. * is not changed.
* *
* This should be used for hashes that have to remain stable over the entire * This should be used for hashes that have to remain stable over the entire
* cluster. An example are file IDs, which have to be stable over several workers. * cluster. An example are file IDs, which have to be stable over several workers.
* *
* @param bytes Bytes to hash * @param bytes Bytes to hash
* *
* @param size Size of bytes * @param size Size of bytes
* *
* @param result Result of the hashing operation. * @param result Result of the hashing operation.
*/ */
static void StaticHash256(const void* bytes, uint64_t size, hash256_t* result); static void StaticHash256(const void* bytes, uint64_t size, hash256_t* result);
/** /**
* Size of the initial seed * Size of the initial seed
*/ */
constexpr static int SEED_INIT_SIZE = 20; constexpr static int SEED_INIT_SIZE = 20;
/** /**
* Initialize the (typically process-specific) seeds. This function is indirectly * Initialize the (typically process-specific) seeds. This function is indirectly
* called from main, during early initialization. * called from main, during early initialization.
* *
* @param seed_data random data used as an initial seed * @param seed_data random data used as an initial seed
*/ */
static void InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed_data); static void InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed_data);
/** /**
* Returns true if the process-specific seeds have been initialized * Returns true if the process-specific seeds have been initialized
* *
* @return True if the seeds are initialized * @return True if the seeds are initialized
*/ */
static bool IsInitialized() { return seeds_initialized; } static bool IsInitialized() { return seeds_initialized; }
/** /**
* Initializes the static hash seeds using the script-level * Initializes the static hash seeds using the script-level
* :zeek:see:`digest_salt` constant. * :zeek:see:`digest_salt` constant.
*/ */
static void InitOptions(); static void InitOptions();
private: private:
// actually HHKey. This key changes each start (unless a seed is specified) // actually HHKey. This key changes each start (unless a seed is specified)
alignas(32) static uint64_t shared_highwayhash_key[4]; alignas(32) static uint64_t shared_highwayhash_key[4];
// actually HHKey. This key is installation specific and sourced from the digest_salt // actually HHKey. This key is installation specific and sourced from the digest_salt
// script-level const. // script-level const.
alignas(32) static uint64_t cluster_highwayhash_key[4]; alignas(32) static uint64_t cluster_highwayhash_key[4];
// actually HH_U64, which has the same type. This key changes each start (unless a seed is // actually HH_U64, which has the same type. This key changes each start (unless a seed is
// specified) // specified)
alignas(16) static unsigned long long shared_siphash_key[2]; alignas(16) static unsigned long long shared_siphash_key[2];
// This key changes each start (unless a seed is specified) // This key changes each start (unless a seed is specified)
inline static uint8_t shared_hmac_md5_key[16]; inline static uint8_t shared_hmac_md5_key[16];
inline static bool seeds_initialized = false; inline static bool seeds_initialized = false;
friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, unsigned char digest[16]);
unsigned char digest[16]); friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*);
friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*); };
};
enum HashKeyTag enum HashKeyTag { HASH_KEY_INT, HASH_KEY_DOUBLE, HASH_KEY_STRING };
{
HASH_KEY_INT,
HASH_KEY_DOUBLE,
HASH_KEY_STRING
};
constexpr int NUM_HASH_KEYS = HASH_KEY_STRING + 1; constexpr int NUM_HASH_KEYS = HASH_KEY_STRING + 1;
class HashKey class HashKey {
{
public: public:
explicit HashKey() { key_u.u32 = 0; } explicit HashKey() { key_u.u32 = 0; }
explicit HashKey(bool b); explicit HashKey(bool b);
explicit HashKey(int i); explicit HashKey(int i);
explicit HashKey(zeek_int_t bi); explicit HashKey(zeek_int_t bi);
explicit HashKey(zeek_uint_t bu); explicit HashKey(zeek_uint_t bu);
explicit HashKey(uint32_t u); explicit HashKey(uint32_t u);
HashKey(const uint32_t u[], size_t n); HashKey(const uint32_t u[], size_t n);
explicit HashKey(double d); explicit HashKey(double d);
explicit HashKey(const void* p); explicit HashKey(const void* p);
explicit HashKey(const char* s); // No copying, no ownership explicit HashKey(const char* s); // No copying, no ownership
explicit HashKey(const String* s); // No copying, no ownership explicit HashKey(const String* s); // No copying, no ownership
// Builds a key from the given chunk of bytes. Copies the data. // Builds a key from the given chunk of bytes. Copies the data.
HashKey(const void* bytes, size_t size); HashKey(const void* bytes, size_t size);
// Create a HashKey given all of its components. Copies the key. // Create a HashKey given all of its components. Copies the key.
HashKey(const void* key, size_t size, hash_t hash); HashKey(const void* key, size_t size, hash_t hash);
// Create a Hashkey given all of its components *without* // Create a Hashkey given all of its components *without*
// copying the key and *without* taking ownership. Note that // copying the key and *without* taking ownership. Note that
// "dont_copy" is a type placeholder to differentiate this member // "dont_copy" is a type placeholder to differentiate this member
// function from the one above; its value is not used. // function from the one above; its value is not used.
HashKey(const void* key, size_t size, hash_t hash, bool dont_copy); HashKey(const void* key, size_t size, hash_t hash, bool dont_copy);
// Copy constructor. Always copies the key. // Copy constructor. Always copies the key.
HashKey(const HashKey& other); HashKey(const HashKey& other);
// Move constructor. Takes ownership of the key. // Move constructor. Takes ownership of the key.
HashKey(HashKey&& other) noexcept; HashKey(HashKey&& other) noexcept;
// Destructor // Destructor
~HashKey(); ~HashKey();
// Hands over the key to the caller. This means that if the // Hands over the key to the caller. This means that if the
// key is our dynamic, we give it to the caller and mark it // key is our dynamic, we give it to the caller and mark it
// as not our dynamic. If initially it's not our dynamic, // as not our dynamic. If initially it's not our dynamic,
// we give them a copy of it. // we give them a copy of it.
void* TakeKey(); void* TakeKey();
const void* Key() const { return key; } const void* Key() const { return key; }
size_t Size() const { return size; } size_t Size() const { return size; }
hash_t Hash() const; hash_t Hash() const;
static hash_t HashBytes(const void* bytes, size_t size); static hash_t HashBytes(const void* bytes, size_t size);
// A HashKey is "allocated" when the underlying key points somewhere // A HashKey is "allocated" when the underlying key points somewhere
// other than our internal key_u union. This is almost like // other than our internal key_u union. This is almost like
// is_our_dynamic, but remains true also after TakeKey(). // is_our_dynamic, but remains true also after TakeKey().
bool IsAllocated() const bool IsAllocated() const { return (key != nullptr && key != reinterpret_cast<const char*>(&key_u)); }
{
return (key != nullptr && key != reinterpret_cast<const char*>(&key_u));
}
// Buffer size reservation. Repeated calls to these methods // Buffer size reservation. Repeated calls to these methods
// incrementally build up the eventual buffer size to be allocated via // incrementally build up the eventual buffer size to be allocated via
// Allocate(). // Allocate().
template <typename T> void ReserveType(const char* tag) { Reserve(tag, sizeof(T), sizeof(T)); } template<typename T>
void Reserve(const char* tag, size_t addl_size, size_t alignment = 0); void ReserveType(const char* tag) {
Reserve(tag, sizeof(T), sizeof(T));
}
void Reserve(const char* tag, size_t addl_size, size_t alignment = 0);
// Allocates the reserved amount of memory // Allocates the reserved amount of memory
void Allocate(); void Allocate();
// Incremental writes into an allocated HashKey. The tags give context // Incremental writes into an allocated HashKey. The tags give context
// to what's being written and are only used in debug-build log streams. // to what's being written and are only used in debug-build log streams.
// When true, the alignment boolean will cause write-marker alignment to // When true, the alignment boolean will cause write-marker alignment to
// the size of the item being written, otherwise writes happen directly // the size of the item being written, otherwise writes happen directly
// at the current marker. // at the current marker.
void Write(const char* tag, bool b); void Write(const char* tag, bool b);
void Write(const char* tag, int i, bool align = true); void Write(const char* tag, int i, bool align = true);
void Write(const char* tag, zeek_int_t bi, bool align = true); void Write(const char* tag, zeek_int_t bi, bool align = true);
void Write(const char* tag, zeek_uint_t bu, bool align = true); void Write(const char* tag, zeek_uint_t bu, bool align = true);
void Write(const char* tag, uint32_t u, bool align = true); void Write(const char* tag, uint32_t u, bool align = true);
void Write(const char* tag, double d, bool align = true); void Write(const char* tag, double d, bool align = true);
void Write(const char* tag, const void* bytes, size_t n, size_t alignment = 0); void Write(const char* tag, const void* bytes, size_t n, size_t alignment = 0);
// For writes that copy directly into the allocated buffer, this method // For writes that copy directly into the allocated buffer, this method
// advances the write marker without modifying content. // advances the write marker without modifying content.
void SkipWrite(const char* tag, size_t n); void SkipWrite(const char* tag, size_t n);
// Aligns the write marker to the next multiple of the given alignment size. // Aligns the write marker to the next multiple of the given alignment size.
void AlignWrite(size_t alignment); void AlignWrite(size_t alignment);
// Bounds check: if the buffer does not have at least n bytes available // Bounds check: if the buffer does not have at least n bytes available
// to write into, triggers an InternalError. // to write into, triggers an InternalError.
void EnsureWriteSpace(size_t n) const; void EnsureWriteSpace(size_t n) const;
// Reads don't modify our internal state except for the read offset // Reads don't modify our internal state except for the read offset
// pointer. To blend in more seamlessly with the rest of Zeek we keep // pointer. To blend in more seamlessly with the rest of Zeek we keep
// reads a const operation. // reads a const operation.
void ResetRead() const { read_size = 0; } void ResetRead() const { read_size = 0; }
// Incremental reads from an allocated HashKey. As with writes, the // Incremental reads from an allocated HashKey. As with writes, the
// tags are only used for debug-build logging, and alignment prior // tags are only used for debug-build logging, and alignment prior
// to the read of the item is controlled by the align boolean. // to the read of the item is controlled by the align boolean.
void Read(const char* tag, bool& b) const; void Read(const char* tag, bool& b) const;
void Read(const char* tag, int& i, bool align = true) const; void Read(const char* tag, int& i, bool align = true) const;
void Read(const char* tag, zeek_int_t& bi, bool align = true) const; void Read(const char* tag, zeek_int_t& bi, bool align = true) const;
void Read(const char* tag, zeek_uint_t& bu, bool align = true) const; void Read(const char* tag, zeek_uint_t& bu, bool align = true) const;
void Read(const char* tag, uint32_t& u, bool align = true) const; void Read(const char* tag, uint32_t& u, bool align = true) const;
void Read(const char* tag, double& d, bool align = true) const; void Read(const char* tag, double& d, bool align = true) const;
void Read(const char* tag, void* out, size_t n, size_t alignment = 0) const; void Read(const char* tag, void* out, size_t n, size_t alignment = 0) const;
// These mirror the corresponding write methods above. // These mirror the corresponding write methods above.
void SkipRead(const char* tag, size_t n) const; void SkipRead(const char* tag, size_t n) const;
void AlignRead(size_t alignment) const; void AlignRead(size_t alignment) const;
void EnsureReadSpace(size_t n) const; void EnsureReadSpace(size_t n) const;
void* KeyAtWrite() { return static_cast<void*>(key + write_size); } void* KeyAtWrite() { return static_cast<void*>(key + write_size); }
const void* KeyAtRead() const { return static_cast<void*>(key + read_size); } const void* KeyAtRead() const { return static_cast<void*>(key + read_size); }
const void* KeyEnd() const { return static_cast<void*>(key + size); } const void* KeyEnd() const { return static_cast<void*>(key + size); }
void Describe(ODesc* d) const; void Describe(ODesc* d) const;
bool operator==(const HashKey& other) const; bool operator==(const HashKey& other) const;
bool operator!=(const HashKey& other) const; bool operator!=(const HashKey& other) const;
bool Equal(const void* other_key, size_t other_size, hash_t other_hash) const; bool Equal(const void* other_key, size_t other_size, hash_t other_hash) const;
// Copy operator. Always copies the key. // Copy operator. Always copies the key.
HashKey& operator=(const HashKey& other); HashKey& operator=(const HashKey& other);
// Move operator. Takes ownership of the key. // Move operator. Takes ownership of the key.
HashKey& operator=(HashKey&& other) noexcept; HashKey& operator=(HashKey&& other) noexcept;
protected: protected:
char* CopyKey(const char* key, size_t size) const; char* CopyKey(const char* key, size_t size) const;
// Payload setters for types stored directly in the key_u union. These // Payload setters for types stored directly in the key_u union. These
// adjust the size and write_size markers to indicate a full buffer, and // adjust the size and write_size markers to indicate a full buffer, and
// use the key_u union for storage. // use the key_u union for storage.
void Set(bool b); void Set(bool b);
void Set(int i); void Set(int i);
void Set(zeek_int_t bi); void Set(zeek_int_t bi);
void Set(zeek_uint_t bu); void Set(zeek_uint_t bu);
void Set(uint32_t u); void Set(uint32_t u);
void Set(double d); void Set(double d);
void Set(const void* p); void Set(const void* p);
union { union {
bool b; bool b;
int i; int i;
zeek_int_t bi; zeek_int_t bi;
uint32_t u32; uint32_t u32;
double d; double d;
const void* p; const void* p;
} key_u; } key_u;
char* key = nullptr; char* key = nullptr;
mutable hash_t hash = 0; mutable hash_t hash = 0;
size_t size = 0; size_t size = 0;
bool is_our_dynamic = false; bool is_our_dynamic = false;
size_t write_size = 0; size_t write_size = 0;
mutable size_t read_size = 0; mutable size_t read_size = 0;
}; };
extern void init_hash_function(); extern void init_hash_function();
} // namespace zeek::detail } // namespace zeek::detail

1137
src/ID.cc

File diff suppressed because it is too large Load diff

220
src/ID.h
View file

@ -13,8 +13,7 @@
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/TraverseTypes.h" #include "zeek/TraverseTypes.h"
namespace zeek namespace zeek {
{
class Func; class Func;
class Val; class Val;
@ -31,29 +30,22 @@ using EnumTypePtr = IntrusivePtr<EnumType>;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
} } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class Attributes; class Attributes;
class Expr; class Expr;
using ExprPtr = IntrusivePtr<Expr>; using ExprPtr = IntrusivePtr<Expr>;
enum InitClass enum InitClass {
{ INIT_NONE,
INIT_NONE, INIT_FULL,
INIT_FULL, INIT_EXTRA,
INIT_EXTRA, INIT_REMOVE,
INIT_REMOVE, INIT_SKIP,
INIT_SKIP, };
}; enum IDScope { SCOPE_FUNCTION, SCOPE_MODULE, SCOPE_GLOBAL };
enum IDScope
{
SCOPE_FUNCTION,
SCOPE_MODULE,
SCOPE_GLOBAL
};
class ID; class ID;
using IDPtr = IntrusivePtr<ID>; using IDPtr = IntrusivePtr<ID>;
@ -61,130 +53,131 @@ using IDSet = std::unordered_set<const ID*>;
class IDOptInfo; class IDOptInfo;
class ID final : public Obj, public notifier::detail::Modifiable class ID final : public Obj, public notifier::detail::Modifiable {
{
public: public:
static inline const IDPtr nil; static inline const IDPtr nil;
ID(const char* name, IDScope arg_scope, bool arg_is_export); ID(const char* name, IDScope arg_scope, bool arg_is_export);
~ID() override; ~ID() override;
const char* Name() const { return name; } const char* Name() const { return name; }
int Scope() const { return scope; } int Scope() const { return scope; }
bool IsGlobal() const { return scope != SCOPE_FUNCTION; } bool IsGlobal() const { return scope != SCOPE_FUNCTION; }
bool IsExport() const { return is_export; } bool IsExport() const { return is_export; }
void SetExport() { is_export = true; } void SetExport() { is_export = true; }
std::string ModuleName() const; std::string ModuleName() const;
void SetType(TypePtr t); void SetType(TypePtr t);
const TypePtr& GetType() const { return type; } const TypePtr& GetType() const { return type; }
template <class T> IntrusivePtr<T> GetType() const { return cast_intrusive<T>(type); } template<class T>
IntrusivePtr<T> GetType() const {
return cast_intrusive<T>(type);
}
bool IsType() const { return is_type; } bool IsType() const { return is_type; }
void MakeType() { is_type = true; } void MakeType() { is_type = true; }
void SetVal(ValPtr v); void SetVal(ValPtr v);
void SetVal(ValPtr v, InitClass c); void SetVal(ValPtr v, InitClass c);
void SetVal(ExprPtr ev, InitClass c); void SetVal(ExprPtr ev, InitClass c);
bool HasVal() const { return val != nullptr; } bool HasVal() const { return val != nullptr; }
const ValPtr& GetVal() const { return val; } const ValPtr& GetVal() const { return val; }
void ClearVal(); void ClearVal();
void SetConst() { is_const = true; } void SetConst() { is_const = true; }
bool IsConst() const { return is_const; } bool IsConst() const { return is_const; }
void SetOption(); void SetOption();
bool IsOption() const { return is_option; } bool IsOption() const { return is_option; }
bool IsBlank() const { return is_blank; }; bool IsBlank() const { return is_blank; };
void SetEnumConst() { is_enum_const = true; } void SetEnumConst() { is_enum_const = true; }
bool IsEnumConst() const { return is_enum_const; } bool IsEnumConst() const { return is_enum_const; }
void SetOffset(int arg_offset) { offset = arg_offset; } void SetOffset(int arg_offset) { offset = arg_offset; }
int Offset() const { return offset; } int Offset() const { return offset; }
bool IsRedefinable() const; bool IsRedefinable() const;
void SetAttrs(AttributesPtr attr); void SetAttrs(AttributesPtr attr);
void AddAttr(AttrPtr a, bool is_redef = false); void AddAttr(AttrPtr a, bool is_redef = false);
void AddAttrs(AttributesPtr attr, bool is_redef = false); void AddAttrs(AttributesPtr attr, bool is_redef = false);
void RemoveAttr(AttrTag a); void RemoveAttr(AttrTag a);
void UpdateValAttrs(); void UpdateValAttrs();
const AttributesPtr& GetAttrs() const { return attrs; } const AttributesPtr& GetAttrs() const { return attrs; }
const AttrPtr& GetAttr(AttrTag t) const; const AttrPtr& GetAttr(AttrTag t) const;
bool IsDeprecated() const; bool IsDeprecated() const;
void MakeDeprecated(ExprPtr deprecation); void MakeDeprecated(ExprPtr deprecation);
std::string GetDeprecationWarning() const; std::string GetDeprecationWarning() const;
void Error(const char* msg, const Obj* o2 = nullptr); void Error(const char* msg, const Obj* o2 = nullptr);
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
// Adds type and value to description. // Adds type and value to description.
void DescribeExtended(ODesc* d) const; void DescribeExtended(ODesc* d) const;
// Produces a description that's reST-ready. // Produces a description that's reST-ready.
void DescribeReST(ODesc* d, bool roles_only = false) const; void DescribeReST(ODesc* d, bool roles_only = false) const;
void DescribeReSTShort(ODesc* d) const; void DescribeReSTShort(ODesc* d) const;
bool DoInferReturnType() const { return infer_return_type; } bool DoInferReturnType() const { return infer_return_type; }
void SetInferReturnType(bool infer) { infer_return_type = infer; } void SetInferReturnType(bool infer) { infer_return_type = infer; }
TraversalCode Traverse(TraversalCallback* cb) const; TraversalCode Traverse(TraversalCallback* cb) const;
bool HasOptionHandlers() const { return ! option_handlers.empty(); } bool HasOptionHandlers() const { return ! option_handlers.empty(); }
void AddOptionHandler(FuncPtr callback, int priority); void AddOptionHandler(FuncPtr callback, int priority);
std::vector<Func*> GetOptionHandlers() const; std::vector<Func*> GetOptionHandlers() const;
IDOptInfo* GetOptInfo() const { return opt_info; } IDOptInfo* GetOptInfo() const { return opt_info; }
void ClearOptInfo(); void ClearOptInfo();
protected: protected:
void EvalFunc(ExprPtr ef, ExprPtr ev); void EvalFunc(ExprPtr ef, ExprPtr ev);
#ifdef DEBUG #ifdef DEBUG
void UpdateValID(); void UpdateValID();
#endif #endif
const char* name; const char* name;
IDScope scope; IDScope scope;
bool is_export; bool is_export;
bool infer_return_type; bool infer_return_type;
TypePtr type; TypePtr type;
bool is_const, is_enum_const, is_type, is_option, is_blank; bool is_const, is_enum_const, is_type, is_option, is_blank;
int offset; int offset;
ValPtr val; ValPtr val;
AttributesPtr attrs; AttributesPtr attrs;
// contains list of functions that are called when an option changes // contains list of functions that are called when an option changes
std::multimap<int, FuncPtr> option_handlers; std::multimap<int, FuncPtr> option_handlers;
// Information managed by script optimization. We package this // Information managed by script optimization. We package this
// up into a separate object for purposes of modularity, and, // up into a separate object for purposes of modularity, and,
// via the associated pointer, to allow it to be modified in // via the associated pointer, to allow it to be modified in
// contexts where the ID is itself "const". // contexts where the ID is itself "const".
IDOptInfo* opt_info; IDOptInfo* opt_info;
}; };
} // namespace zeek::detail } // namespace zeek::detail
namespace zeek::id namespace zeek::id {
{
/** /**
* Lookup an ID in the global module and return it, if one exists; * Lookup an ID in the global module and return it, if one exists;
@ -208,10 +201,10 @@ const TypePtr& find_type(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The type of the identifier. * @return The type of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_type(std::string_view name) template<class T>
{ IntrusivePtr<T> find_type(std::string_view name) {
return cast_intrusive<T>(find_type(name)); return cast_intrusive<T>(find_type(name));
} }
/** /**
* Lookup an ID by its name and return its value. A fatal occurs if the ID * Lookup an ID by its name and return its value. A fatal occurs if the ID
@ -227,10 +220,10 @@ const ValPtr& find_val(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The current value of the identifier. * @return The current value of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_val(std::string_view name) template<class T>
{ IntrusivePtr<T> find_val(std::string_view name) {
return cast_intrusive<T>(find_val(name)); return cast_intrusive<T>(find_val(name));
} }
/** /**
* Lookup an ID by its name and return its value. A fatal occurs if the ID * Lookup an ID by its name and return its value. A fatal occurs if the ID
@ -246,10 +239,10 @@ const ValPtr& find_const(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The current value of the identifier. * @return The current value of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_const(std::string_view name) template<class T>
{ IntrusivePtr<T> find_const(std::string_view name) {
return cast_intrusive<T>(find_const(name)); return cast_intrusive<T>(find_const(name));
} }
/** /**
* Lookup an ID by its name and return the function it references. * Lookup an ID by its name and return the function it references.
@ -271,10 +264,9 @@ extern TableTypePtr count_set;
extern VectorTypePtr string_vec; extern VectorTypePtr string_vec;
extern VectorTypePtr index_vec; extern VectorTypePtr index_vec;
namespace detail namespace detail {
{
void init_types(); void init_types();
} // namespace detail } // namespace detail
} // namespace zeek::id } // namespace zeek::id

1529
src/IP.cc

File diff suppressed because it is too large Load diff

819
src/IP.h
View file

@ -20,8 +20,7 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class IPAddr; class IPAddr;
class RecordVal; class RecordVal;
@ -29,530 +28,478 @@ class VectorVal;
using RecordValPtr = IntrusivePtr<RecordVal>; using RecordValPtr = IntrusivePtr<RecordVal>;
using VectorValPtr = IntrusivePtr<VectorVal>; using VectorValPtr = IntrusivePtr<VectorVal>;
namespace detail namespace detail {
{
class FragReassembler; class FragReassembler;
} }
#ifndef IPPROTO_MOBILITY #ifndef IPPROTO_MOBILITY
#define IPPROTO_MOBILITY 135 #define IPPROTO_MOBILITY 135
#endif #endif
struct ip6_mobility struct ip6_mobility {
{ uint8_t ip6mob_payload;
uint8_t ip6mob_payload; uint8_t ip6mob_len;
uint8_t ip6mob_len; uint8_t ip6mob_type;
uint8_t ip6mob_type; uint8_t ip6mob_rsv;
uint8_t ip6mob_rsv; uint16_t ip6mob_chksum;
uint16_t ip6mob_chksum; };
};
/** /**
* Base class for IPv6 header/extensions. * Base class for IPv6 header/extensions.
*/ */
class IPv6_Hdr class IPv6_Hdr {
{
public: public:
/** /**
* Construct an IPv6 header or extension header from assigned type number. * Construct an IPv6 header or extension header from assigned type number.
*/ */
IPv6_Hdr(uint8_t t, const u_char* d) : type(t), data(d) { } IPv6_Hdr(uint8_t t, const u_char* d) : type(t), data(d) {}
/** /**
* Replace the value of the next protocol field. * Replace the value of the next protocol field.
*/ */
void ChangeNext(uint8_t next_type) void ChangeNext(uint8_t next_type) {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: ((ip6_hdr*)data)->ip6_nxt = next_type; break;
{ case IPPROTO_HOPOPTS:
case IPPROTO_IPV6: case IPPROTO_DSTOPTS:
((ip6_hdr*)data)->ip6_nxt = next_type; case IPPROTO_ROUTING:
break; case IPPROTO_FRAGMENT:
case IPPROTO_HOPOPTS: case IPPROTO_AH:
case IPPROTO_DSTOPTS: case IPPROTO_MOBILITY: ((ip6_ext*)data)->ip6e_nxt = next_type; break;
case IPPROTO_ROUTING: case IPPROTO_ESP:
case IPPROTO_FRAGMENT: default: break;
case IPPROTO_AH: }
case IPPROTO_MOBILITY: }
((ip6_ext*)data)->ip6e_nxt = next_type;
break;
case IPPROTO_ESP:
default:
break;
}
}
~IPv6_Hdr() { } ~IPv6_Hdr() {}
/** /**
* Returns the assigned IPv6 extension header type number of the header * Returns the assigned IPv6 extension header type number of the header
* that immediately follows this one. * that immediately follows this one.
*/ */
uint8_t NextHdr() const uint8_t NextHdr() const {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: return ((ip6_hdr*)data)->ip6_nxt;
{ case IPPROTO_HOPOPTS:
case IPPROTO_IPV6: case IPPROTO_DSTOPTS:
return ((ip6_hdr*)data)->ip6_nxt; case IPPROTO_ROUTING:
case IPPROTO_HOPOPTS: case IPPROTO_FRAGMENT:
case IPPROTO_DSTOPTS: case IPPROTO_AH:
case IPPROTO_ROUTING: case IPPROTO_MOBILITY: return ((ip6_ext*)data)->ip6e_nxt;
case IPPROTO_FRAGMENT: case IPPROTO_ESP:
case IPPROTO_AH: default: return IPPROTO_NONE;
case IPPROTO_MOBILITY: }
return ((ip6_ext*)data)->ip6e_nxt; }
case IPPROTO_ESP:
default:
return IPPROTO_NONE;
}
}
/** /**
* Returns the length of the header in bytes. * Returns the length of the header in bytes.
*/ */
uint16_t Length() const uint16_t Length() const {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: return 40;
{ case IPPROTO_HOPOPTS:
case IPPROTO_IPV6: case IPPROTO_DSTOPTS:
return 40; case IPPROTO_ROUTING:
case IPPROTO_HOPOPTS: case IPPROTO_MOBILITY: return 8 + 8 * ((ip6_ext*)data)->ip6e_len;
case IPPROTO_DSTOPTS: case IPPROTO_FRAGMENT: return 8;
case IPPROTO_ROUTING: case IPPROTO_AH: return 8 + 4 * ((ip6_ext*)data)->ip6e_len;
case IPPROTO_MOBILITY: case IPPROTO_ESP: return 8; // encrypted payload begins after 8 bytes
return 8 + 8 * ((ip6_ext*)data)->ip6e_len; default: return 0;
case IPPROTO_FRAGMENT: }
return 8; }
case IPPROTO_AH:
return 8 + 4 * ((ip6_ext*)data)->ip6e_len;
case IPPROTO_ESP:
return 8; // encrypted payload begins after 8 bytes
default:
return 0;
}
}
/** /**
* Returns the RFC 1700 et seq. IANA assigned number for the header. * Returns the RFC 1700 et seq. IANA assigned number for the header.
*/ */
uint8_t Type() const { return type; } uint8_t Type() const { return type; }
/** /**
* Returns pointer to the start of where header structure resides in memory. * Returns pointer to the start of where header structure resides in memory.
*/ */
const u_char* Data() const { return data; } const u_char* Data() const { return data; }
/** /**
* Returns the script-layer record representation of the header. * Returns the script-layer record representation of the header.
*/ */
RecordValPtr ToVal(VectorValPtr chain) const; RecordValPtr ToVal(VectorValPtr chain) const;
RecordValPtr ToVal() const; RecordValPtr ToVal() const;
protected: protected:
uint8_t type; uint8_t type;
const u_char* data; const u_char* data;
private: private:
bool IsOptionTruncated(uint16_t off) const; bool IsOptionTruncated(uint16_t off) const;
}; };
class IPv6_Hdr_Chain class IPv6_Hdr_Chain {
{
public: public:
/** /**
* Initializes the header chain from an IPv6 header structure. * Initializes the header chain from an IPv6 header structure.
*/ */
IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint64_t len) { Init(ip6, len, false); } IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint64_t len) { Init(ip6, len, false); }
~IPv6_Hdr_Chain(); ~IPv6_Hdr_Chain();
/** /**
* @return a copy of the header chain, but with pointers to individual * @return a copy of the header chain, but with pointers to individual
* IPv6 headers now pointing within \a new_hdr. * IPv6 headers now pointing within \a new_hdr.
*/ */
IPv6_Hdr_Chain* Copy(const struct ip6_hdr* new_hdr) const; IPv6_Hdr_Chain* Copy(const struct ip6_hdr* new_hdr) const;
/** /**
* Returns the number of headers in the chain. * Returns the number of headers in the chain.
*/ */
size_t Size() const { return chain.size(); } size_t Size() const { return chain.size(); }
/** /**
* Returns the sum of the length of all headers in the chain in bytes. * Returns the sum of the length of all headers in the chain in bytes.
*/ */
uint16_t TotalLength() const { return length; } uint16_t TotalLength() const { return length; }
/** /**
* Accesses the header at the given location in the chain. * Accesses the header at the given location in the chain.
*/ */
const IPv6_Hdr* operator[](const size_t i) const { return chain[i]; } const IPv6_Hdr* operator[](const size_t i) const { return chain[i]; }
/** /**
* Returns whether the header chain indicates a fragmented packet. * Returns whether the header chain indicates a fragmented packet.
*/ */
bool IsFragment() const; bool IsFragment() const;
/** /**
* Returns pointer to fragment header structure if the chain contains one. * Returns pointer to fragment header structure if the chain contains one.
*/ */
const struct ip6_frag* GetFragHdr() const const struct ip6_frag* GetFragHdr() const {
{ return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr;
return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr; }
}
/** /**
* If the header chain is a fragment, returns the offset in number of bytes * If the header chain is a fragment, returns the offset in number of bytes
* relative to the start of the Fragmentable Part of the original packet. * relative to the start of the Fragmentable Part of the original packet.
*/ */
uint16_t FragOffset() const uint16_t FragOffset() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0; }
{
return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0;
}
/** /**
* If the header chain is a fragment, returns the identification field. * If the header chain is a fragment, returns the identification field.
*/ */
uint32_t ID() const { return IsFragment() ? ntohl(GetFragHdr()->ip6f_ident) : 0; } uint32_t ID() const { return IsFragment() ? ntohl(GetFragHdr()->ip6f_ident) : 0; }
/** /**
* If the header chain is a fragment, returns the M (more fragments) flag. * If the header chain is a fragment, returns the M (more fragments) flag.
*/ */
int MF() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0x0001) != 0 : 0; } int MF() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0x0001) != 0 : 0; }
/** /**
* If the chain contains a Destination Options header with a Home Address * If the chain contains a Destination Options header with a Home Address
* option as defined by Mobile IPv6 (RFC 6275), then return it, else * option as defined by Mobile IPv6 (RFC 6275), then return it, else
* return the source address in the main IPv6 header. * return the source address in the main IPv6 header.
*/ */
IPAddr SrcAddr() const; IPAddr SrcAddr() const;
/** /**
* If the chain contains a Routing header with non-zero segments left, * If the chain contains a Routing header with non-zero segments left,
* then return the last address of the first such header, else return * then return the last address of the first such header, else return
* the destination address of the main IPv6 header. * the destination address of the main IPv6 header.
*/ */
IPAddr DstAddr() const; IPAddr DstAddr() const;
/** /**
* Returns a vector of ip6_ext_hdr RecordVals that includes script-layer * Returns a vector of ip6_ext_hdr RecordVals that includes script-layer
* representation of all extension headers in the chain. * representation of all extension headers in the chain.
*/ */
VectorValPtr ToVal() const; VectorValPtr ToVal() const;
protected: protected:
// for access to protected ctor that changes next header values that // for access to protected ctor that changes next header values that
// point to a fragment // point to a fragment
friend class detail::FragReassembler; friend class detail::FragReassembler;
IPv6_Hdr_Chain() = default; IPv6_Hdr_Chain() = default;
/** /**
* Initializes the header chain from an IPv6 header structure, and replaces * Initializes the header chain from an IPv6 header structure, and replaces
* the first next protocol pointer field that points to a fragment header. * the first next protocol pointer field that points to a fragment header.
*/ */
IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) { Init(ip6, len, true, next); }
{
Init(ip6, len, true, next);
}
/** /**
* Initializes the header chain from an IPv6 header structure of a given * Initializes the header chain from an IPv6 header structure of a given
* length, possibly setting the first next protocol pointer field that * length, possibly setting the first next protocol pointer field that
* points to a fragment header. * points to a fragment header.
*/ */
void Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, uint16_t next = 0); void Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, uint16_t next = 0);
/** /**
* Process a routing header and allocate/remember the final destination * Process a routing header and allocate/remember the final destination
* address if it has segments left and is a valid routing header. * address if it has segments left and is a valid routing header.
*/ */
void ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len); void ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len);
/** /**
* Inspect a Destination Option header's options for things we need to * Inspect a Destination Option header's options for things we need to
* remember, such as the Home Address option from Mobile IPv6. * remember, such as the Home Address option from Mobile IPv6.
*/ */
void ProcessDstOpts(const struct ip6_dest* d, uint16_t len); void ProcessDstOpts(const struct ip6_dest* d, uint16_t len);
std::vector<IPv6_Hdr*> chain; std::vector<IPv6_Hdr*> chain;
/** /**
* The summation of all header lengths in the chain in bytes. * The summation of all header lengths in the chain in bytes.
*/ */
uint16_t length = 0; uint16_t length = 0;
/** /**
* Home Address of the packet's source as defined by Mobile IPv6 (RFC 6275). * Home Address of the packet's source as defined by Mobile IPv6 (RFC 6275).
*/ */
IPAddr* homeAddr = nullptr; IPAddr* homeAddr = nullptr;
/** /**
* The final destination address in chain's first Routing header that has * The final destination address in chain's first Routing header that has
* non-zero segments left. * non-zero segments left.
*/ */
IPAddr* finalDst = nullptr; IPAddr* finalDst = nullptr;
}; };
/** /**
* A class that wraps either an IPv4 or IPv6 packet and abstracts methods * A class that wraps either an IPv4 or IPv6 packet and abstracts methods
* for inquiring about common features between the two. * for inquiring about common features between the two.
*/ */
class IP_Hdr class IP_Hdr {
{
public: public:
/** /**
* Construct the header wrapper from an IPv4 packet. Caller must have * Construct the header wrapper from an IPv4 packet. Caller must have
* already checked that the header is not truncated. * already checked that the header is not truncated.
* @param arg_ip4 pointer to memory containing an IPv4 packet. * @param arg_ip4 pointer to memory containing an IPv4 packet.
* @param arg_del whether to take ownership of \a arg_ip4 pointer's memory. * @param arg_del whether to take ownership of \a arg_ip4 pointer's memory.
* @param reassembled whether this header is for a reassembled packet. * @param reassembled whether this header is for a reassembled packet.
*/ */
IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false) IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false)
: ip4(arg_ip4), del(arg_del), reassembled(reassembled) : ip4(arg_ip4), del(arg_del), reassembled(reassembled) {}
{
}
/** /**
* Construct the header wrapper from an IPv6 packet. Caller must have * Construct the header wrapper from an IPv6 packet. Caller must have
* already checked that the static IPv6 header is not truncated. If * already checked that the static IPv6 header is not truncated. If
* the packet contains extension headers and they are truncated, that can * the packet contains extension headers and they are truncated, that can
* be checked afterwards by comparing \a len with \a TotalLen. E.g. * be checked afterwards by comparing \a len with \a TotalLen. E.g.
* The IP packet analyzer does this to skip truncated packets. * The IP packet analyzer does this to skip truncated packets.
* @param arg_ip6 pointer to memory containing an IPv6 packet. * @param arg_ip6 pointer to memory containing an IPv6 packet.
* @param arg_del whether to take ownership of \a arg_ip6 pointer's memory. * @param arg_del whether to take ownership of \a arg_ip6 pointer's memory.
* @param len the packet's length in bytes. * @param len the packet's length in bytes.
* @param c an already-constructed header chain to take ownership of. * @param c an already-constructed header chain to take ownership of.
* @param reassembled whether this header is for a reassembled packet. * @param reassembled whether this header is for a reassembled packet.
*/ */
IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, const IPv6_Hdr_Chain* c = nullptr,
const IPv6_Hdr_Chain* c = nullptr, bool reassembled = false) bool reassembled = false)
: ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), : ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), reassembled(reassembled) {}
reassembled(reassembled)
{
}
/** /**
* Copy a header. The internal buffer which contains the header data * Copy a header. The internal buffer which contains the header data
* must not be truncated. Also note that if that buffer points to a full * must not be truncated. Also note that if that buffer points to a full
* packet payload, only the IP header portion is copied. * packet payload, only the IP header portion is copied.
*/ */
IP_Hdr* Copy() const; IP_Hdr* Copy() const;
/** /**
* Destructor. * Destructor.
*/ */
~IP_Hdr() ~IP_Hdr() {
{ delete ip6_hdrs;
delete ip6_hdrs;
if ( del ) if ( del ) {
{ delete[](struct ip*) ip4;
delete[](struct ip*) ip4; delete[](struct ip6_hdr*) ip6;
delete[](struct ip6_hdr*) ip6; }
} }
}
/** /**
* If an IPv4 packet is wrapped, return a pointer to it, else null. * If an IPv4 packet is wrapped, return a pointer to it, else null.
*/ */
const struct ip* IP4_Hdr() const { return ip4; } const struct ip* IP4_Hdr() const { return ip4; }
/** /**
* If an IPv6 packet is wrapped, return a pointer to it, else null. * If an IPv6 packet is wrapped, return a pointer to it, else null.
*/ */
const struct ip6_hdr* IP6_Hdr() const { return ip6; } const struct ip6_hdr* IP6_Hdr() const { return ip6; }
/** /**
* Returns the source address held in the IP header. * Returns the source address held in the IP header.
*/ */
IPAddr IPHeaderSrcAddr() const; IPAddr IPHeaderSrcAddr() const;
/** /**
* Returns the destination address held in the IP header. * Returns the destination address held in the IP header.
*/ */
IPAddr IPHeaderDstAddr() const; IPAddr IPHeaderDstAddr() const;
/** /**
* For IPv4 or IPv6 headers that don't contain a Home Address option * For IPv4 or IPv6 headers that don't contain a Home Address option
* (Mobile IPv6, RFC 6275), return source address held in the IP header. * (Mobile IPv6, RFC 6275), return source address held in the IP header.
* For IPv6 headers that contain a Home Address option, return that address. * For IPv6 headers that contain a Home Address option, return that address.
*/ */
IPAddr SrcAddr() const; IPAddr SrcAddr() const;
/** /**
* For IPv4 or IPv6 headers that don't contain a Routing header with * For IPv4 or IPv6 headers that don't contain a Routing header with
* non-zero segments left, return destination address held in the IP header. * non-zero segments left, return destination address held in the IP header.
* For IPv6 headers with a Routing header that has non-zero segments left, * For IPv6 headers with a Routing header that has non-zero segments left,
* return the last address in the first such Routing header. * return the last address in the first such Routing header.
*/ */
IPAddr DstAddr() const; IPAddr DstAddr() const;
/** /**
* Returns a pointer to the payload of the IP packet, usually an * Returns a pointer to the payload of the IP packet, usually an
* upper-layer protocol. * upper-layer protocol.
*/ */
const u_char* Payload() const const u_char* Payload() const {
{ if ( ip4 )
if ( ip4 ) return ((const u_char*)ip4) + ip4->ip_hl * 4;
return ((const u_char*)ip4) + ip4->ip_hl * 4;
return ((const u_char*)ip6) + ip6_hdrs->TotalLength(); return ((const u_char*)ip6) + ip6_hdrs->TotalLength();
} }
/** /**
* Returns a pointer to the mobility header of the IP packet, if present, * Returns a pointer to the mobility header of the IP packet, if present,
* else a null pointer. * else a null pointer.
*/ */
const ip6_mobility* MobilityHeader() const const ip6_mobility* MobilityHeader() const {
{ if ( ip4 )
if ( ip4 ) return nullptr;
return nullptr; else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY )
else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY ) return nullptr;
return nullptr; else
else return (const ip6_mobility*)(*ip6_hdrs)[ip6_hdrs->Size() - 1]->Data();
return (const ip6_mobility*)(*ip6_hdrs)[ip6_hdrs->Size() - 1]->Data(); }
}
/** /**
* Returns the length of the IP packet's payload (length of packet minus * Returns the length of the IP packet's payload (length of packet minus
* header length or, for IPv6, also minus length of all extension headers). * header length or, for IPv6, also minus length of all extension headers).
* *
* Also returns 0 if the IPv4 length field is set to zero - which is, e.g., * Also returns 0 if the IPv4 length field is set to zero - which is, e.g.,
* the case when TCP segment offloading is enabled. * the case when TCP segment offloading is enabled.
*/ */
uint16_t PayloadLen() const uint16_t PayloadLen() const {
{ if ( ip4 ) {
if ( ip4 ) // prevent overflow in case of segment offloading/zeroed header length.
{ auto total_len = ntohs(ip4->ip_len);
// prevent overflow in case of segment offloading/zeroed header length. return total_len ? total_len - ip4->ip_hl * 4 : 0;
auto total_len = ntohs(ip4->ip_len); }
return total_len ? total_len - ip4->ip_hl * 4 : 0;
}
return ntohs(ip6->ip6_plen) + 40 - ip6_hdrs->TotalLength(); return ntohs(ip6->ip6_plen) + 40 - ip6_hdrs->TotalLength();
} }
/** /**
* Returns the length of the IP packet (length of headers and payload). * Returns the length of the IP packet (length of headers and payload).
*/ */
uint32_t TotalLen() const uint32_t TotalLen() const {
{ if ( ip4 )
if ( ip4 ) return ntohs(ip4->ip_len);
return ntohs(ip4->ip_len);
return ntohs(ip6->ip6_plen) + 40; return ntohs(ip6->ip6_plen) + 40;
} }
/** /**
* Returns length of IP packet header (includes extension headers for IPv6). * Returns length of IP packet header (includes extension headers for IPv6).
*/ */
uint16_t HdrLen() const { return ip4 ? ip4->ip_hl * 4 : ip6_hdrs->TotalLength(); } uint16_t HdrLen() const { return ip4 ? ip4->ip_hl * 4 : ip6_hdrs->TotalLength(); }
/** /**
* For IPv6 header chains, returns the type of the last header in the chain. * For IPv6 header chains, returns the type of the last header in the chain.
*/ */
uint8_t LastHeader() const uint8_t LastHeader() const {
{ if ( ip4 )
if ( ip4 ) return IPPROTO_RAW;
return IPPROTO_RAW;
size_t i = ip6_hdrs->Size(); size_t i = ip6_hdrs->Size();
if ( i > 0 ) if ( i > 0 )
return (*ip6_hdrs)[i - 1]->Type(); return (*ip6_hdrs)[i - 1]->Type();
return IPPROTO_NONE; return IPPROTO_NONE;
} }
/** /**
* Returns the protocol type of the IP packet's payload, usually an * Returns the protocol type of the IP packet's payload, usually an
* upper-layer protocol. For IPv6, this returns the last (extension) * upper-layer protocol. For IPv6, this returns the last (extension)
* header's Next Header value. * header's Next Header value.
*/ */
unsigned char NextProto() const unsigned char NextProto() const {
{ if ( ip4 )
if ( ip4 ) return ip4->ip_p;
return ip4->ip_p;
size_t i = ip6_hdrs->Size(); size_t i = ip6_hdrs->Size();
if ( i > 0 ) if ( i > 0 )
return (*ip6_hdrs)[i - 1]->NextHdr(); return (*ip6_hdrs)[i - 1]->NextHdr();
return IPPROTO_NONE; return IPPROTO_NONE;
} }
/** /**
* Returns the IPv4 Time to Live or IPv6 Hop Limit field. * Returns the IPv4 Time to Live or IPv6 Hop Limit field.
*/ */
unsigned char TTL() const { return ip4 ? ip4->ip_ttl : ip6->ip6_hlim; } unsigned char TTL() const { return ip4 ? ip4->ip_ttl : ip6->ip6_hlim; }
/** /**
* Returns whether the IP header indicates this packet is a fragment. * Returns whether the IP header indicates this packet is a fragment.
*/ */
bool IsFragment() const bool IsFragment() const { return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment(); }
{
return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment();
}
/** /**
* Returns the fragment packet's offset in relation to the original * Returns the fragment packet's offset in relation to the original
* packet in bytes. * packet in bytes.
*/ */
uint16_t FragOffset() const uint16_t FragOffset() const { return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset(); }
{
return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset();
}
/** /**
* Returns the fragment packet's identification field. * Returns the fragment packet's identification field.
*/ */
uint32_t ID() const { return ip4 ? ntohs(ip4->ip_id) : ip6_hdrs->ID(); } uint32_t ID() const { return ip4 ? ntohs(ip4->ip_id) : ip6_hdrs->ID(); }
/** /**
* Returns whether a fragment packet's "More Fragments" field is set. * Returns whether a fragment packet's "More Fragments" field is set.
*/ */
int MF() const { return ip4 ? (ntohs(ip4->ip_off) & 0x2000) != 0 : ip6_hdrs->MF(); } int MF() const { return ip4 ? (ntohs(ip4->ip_off) & 0x2000) != 0 : ip6_hdrs->MF(); }
/** /**
* Returns whether a fragment packet's "Don't Fragment" field is set. * Returns whether a fragment packet's "Don't Fragment" field is set.
* Note that IPv6 has no such field. * Note that IPv6 has no such field.
*/ */
int DF() const { return ip4 ? ((ntohs(ip4->ip_off) & 0x4000) != 0) : 0; } int DF() const { return ip4 ? ((ntohs(ip4->ip_off) & 0x4000) != 0) : 0; }
/** /**
* Returns value of an IPv6 header's flow label field or 0 if it's IPv4. * Returns value of an IPv6 header's flow label field or 0 if it's IPv4.
*/ */
uint32_t FlowLabel() const { return ip4 ? 0 : (ntohl(ip6->ip6_flow) & 0x000fffff); } uint32_t FlowLabel() const { return ip4 ? 0 : (ntohl(ip6->ip6_flow) & 0x000fffff); }
/** /**
* Returns number of IP headers in packet (includes IPv6 extension headers). * Returns number of IP headers in packet (includes IPv6 extension headers).
*/ */
size_t NumHeaders() const { return ip4 ? 1 : ip6_hdrs->Size(); } size_t NumHeaders() const { return ip4 ? 1 : ip6_hdrs->Size(); }
/** /**
* Returns an ip_hdr or ip6_hdr_chain RecordVal. * Returns an ip_hdr or ip6_hdr_chain RecordVal.
*/ */
RecordValPtr ToIPHdrVal() const; RecordValPtr ToIPHdrVal() const;
/** /**
* Returns a pkt_hdr RecordVal, which includes not only the IP header, but * Returns a pkt_hdr RecordVal, which includes not only the IP header, but
* also upper-layer (tcp/udp/icmp) headers. * also upper-layer (tcp/udp/icmp) headers.
*/ */
RecordValPtr ToPktHdrVal() const; RecordValPtr ToPktHdrVal() const;
/** /**
* Same as above, but simply add our values into the record at the * Same as above, but simply add our values into the record at the
* specified starting index. * specified starting index.
*/ */
RecordValPtr ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const; RecordValPtr ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const;
bool Reassembled() const { return reassembled; } bool Reassembled() const { return reassembled; }
private: private:
const struct ip* ip4 = nullptr; const struct ip* ip4 = nullptr;
const struct ip6_hdr* ip6 = nullptr; const struct ip6_hdr* ip6 = nullptr;
const IPv6_Hdr_Chain* ip6_hdrs = nullptr; const IPv6_Hdr_Chain* ip6_hdrs = nullptr;
bool del = false; bool del = false;
bool reassembled = false; bool reassembled = false;
}; };
} // namespace zeek } // namespace zeek

View file

@ -13,415 +13,365 @@
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
#include "zeek/analyzer/Manager.h" #include "zeek/analyzer/Manager.h"
namespace zeek namespace zeek {
{
const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{}); const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{});
const IPAddr IPAddr::v6_unspecified = IPAddr(); const IPAddr IPAddr::v6_unspecified = IPAddr();
namespace detail namespace detail {
{
ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, bool one_way) {
TransportProto t, bool one_way) Init(src, dst, src_port, dst_port, t, one_way);
{ }
Init(src, dst, src_port, dst_port, t, one_way);
} ConnKey::ConnKey(const ConnTuple& id) {
Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way);
ConnKey::ConnKey(const ConnTuple& id) }
{
Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way); ConnKey& ConnKey::operator=(const ConnKey& rhs) {
} if ( this == &rhs )
return *this;
ConnKey& ConnKey::operator=(const ConnKey& rhs)
{ // Because of padding in the object, this needs to memset to clear out
if ( this == &rhs ) // the extra memory used by padding. Otherwise, the session key stuff
return *this; // doesn't work quite right.
memset(this, 0, sizeof(ConnKey));
// Because of padding in the object, this needs to memset to clear out
// the extra memory used by padding. Otherwise, the session key stuff memcpy(&ip1, &rhs.ip1, sizeof(in6_addr));
// doesn't work quite right. memcpy(&ip2, &rhs.ip2, sizeof(in6_addr));
memset(this, 0, sizeof(ConnKey)); port1 = rhs.port1;
port2 = rhs.port2;
memcpy(&ip1, &rhs.ip1, sizeof(in6_addr)); transport = rhs.transport;
memcpy(&ip2, &rhs.ip2, sizeof(in6_addr)); valid = rhs.valid;
port1 = rhs.port1;
port2 = rhs.port2; return *this;
transport = rhs.transport; }
valid = rhs.valid;
ConnKey::ConnKey(Val* v) {
return *this; const auto& vt = v->GetType();
} if ( ! IsRecord(vt->Tag()) ) {
valid = false;
ConnKey::ConnKey(Val* v) return;
{ }
const auto& vt = v->GetType();
if ( ! IsRecord(vt->Tag()) ) RecordType* vr = vt->AsRecordType();
{ auto vl = v->As<RecordVal*>();
valid = false;
return; int orig_h, orig_p; // indices into record's value list
} int resp_h, resp_p;
RecordType* vr = vt->AsRecordType(); if ( vr == id::conn_id ) {
auto vl = v->As<RecordVal*>(); orig_h = 0;
orig_p = 1;
int orig_h, orig_p; // indices into record's value list resp_h = 2;
int resp_h, resp_p; resp_p = 3;
}
if ( vr == id::conn_id ) else {
{ // While it's not a conn_id, it may have equivalent fields.
orig_h = 0; orig_h = vr->FieldOffset("orig_h");
orig_p = 1; resp_h = vr->FieldOffset("resp_h");
resp_h = 2; orig_p = vr->FieldOffset("orig_p");
resp_p = 3; resp_p = vr->FieldOffset("resp_p");
}
else if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 ) {
{ valid = false;
// While it's not a conn_id, it may have equivalent fields. return;
orig_h = vr->FieldOffset("orig_h"); }
resp_h = vr->FieldOffset("resp_h");
orig_p = vr->FieldOffset("orig_p"); // ### we ought to check that the fields have the right
resp_p = vr->FieldOffset("resp_p"); // types, too.
}
if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 )
{ const IPAddr& orig_addr = vl->GetFieldAs<AddrVal>(orig_h);
valid = false; const IPAddr& resp_addr = vl->GetFieldAs<AddrVal>(resp_h);
return;
} auto orig_portv = vl->GetFieldAs<PortVal>(orig_p);
auto resp_portv = vl->GetFieldAs<PortVal>(resp_p);
// ### we ought to check that the fields have the right
// types, too. Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), htons((unsigned short)resp_portv->Port()),
} orig_portv->PortType(), false);
}
const IPAddr& orig_addr = vl->GetFieldAs<AddrVal>(orig_h);
const IPAddr& resp_addr = vl->GetFieldAs<AddrVal>(resp_h); void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
bool one_way) {
auto orig_portv = vl->GetFieldAs<PortVal>(orig_p); // Because of padding in the object, this needs to memset to clear out
auto resp_portv = vl->GetFieldAs<PortVal>(resp_p); // the extra memory used by padding. Otherwise, the session key stuff
// doesn't work quite right.
Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), memset(this, 0, sizeof(ConnKey));
htons((unsigned short)resp_portv->Port()), orig_portv->PortType(), false);
} // Lookup up connection based on canonical ordering, which is
// the smaller of <src addr, src port> and <dst addr, dst port>
void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, // followed by the other.
TransportProto t, bool one_way) if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) {
{ ip1 = src.in6;
// Because of padding in the object, this needs to memset to clear out ip2 = dst.in6;
// the extra memory used by padding. Otherwise, the session key stuff port1 = src_port;
// doesn't work quite right. port2 = dst_port;
memset(this, 0, sizeof(ConnKey)); }
else {
// Lookup up connection based on canonical ordering, which is ip1 = dst.in6;
// the smaller of <src addr, src port> and <dst addr, dst port> ip2 = src.in6;
// followed by the other. port1 = dst_port;
if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) port2 = src_port;
{ }
ip1 = src.in6;
ip2 = dst.in6; transport = t;
port1 = src_port; valid = true;
port2 = dst_port; }
}
else } // namespace detail
{
ip1 = dst.in6; IPAddr::IPAddr(const String& s) { Init(s.CheckString()); }
ip2 = src.in6;
port1 = dst_port; std::unique_ptr<detail::HashKey> IPAddr::MakeHashKey() const {
port2 = src_port; return std::make_unique<detail::HashKey>((void*)in6.s6_addr, sizeof(in6.s6_addr));
} }
transport = t; static inline uint32_t bit_mask32(int bottom_bits) {
valid = true; if ( bottom_bits >= 32 )
} return 0xffffffff;
} // namespace detail return (((uint32_t)1) << bottom_bits) - 1;
}
IPAddr::IPAddr(const String& s)
{ void IPAddr::Mask(int top_bits_to_keep) {
Init(s.CheckString()); if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 ) {
} reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep);
return;
std::unique_ptr<detail::HashKey> IPAddr::MakeHashKey() const }
{
return std::make_unique<detail::HashKey>((void*)in6.s6_addr, sizeof(in6.s6_addr)); uint32_t mask_bits[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff};
} std::ldiv_t res = std::ldiv(top_bits_to_keep, 32);
static inline uint32_t bit_mask32(int bottom_bits) if ( res.quot < 4 )
{ mask_bits[res.quot] = htonl(mask_bits[res.quot] & ~bit_mask32(32 - res.rem));
if ( bottom_bits >= 32 )
return 0xffffffff; for ( unsigned int i = res.quot + 1; i < 4; ++i )
mask_bits[i] = 0;
return (((uint32_t)1) << bottom_bits) - 1;
} uint32_t* p = reinterpret_cast<uint32_t*>(in6.s6_addr);
void IPAddr::Mask(int top_bits_to_keep) for ( unsigned int i = 0; i < 4; ++i )
{ p[i] &= mask_bits[i];
if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 ) }
{
reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep); void IPAddr::ReverseMask(int top_bits_to_chop) {
return; if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 ) {
} reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop);
return;
uint32_t mask_bits[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff}; }
std::ldiv_t res = std::ldiv(top_bits_to_keep, 32);
uint32_t mask_bits[4] = {0, 0, 0, 0};
if ( res.quot < 4 ) std::ldiv_t res = std::ldiv(top_bits_to_chop, 32);
mask_bits[res.quot] = htonl(mask_bits[res.quot] & ~bit_mask32(32 - res.rem));
if ( res.quot < 4 )
for ( unsigned int i = res.quot + 1; i < 4; ++i ) mask_bits[res.quot] = htonl(bit_mask32(32 - res.rem));
mask_bits[i] = 0;
for ( unsigned int i = res.quot + 1; i < 4; ++i )
uint32_t* p = reinterpret_cast<uint32_t*>(in6.s6_addr); mask_bits[i] = 0xffffffff;
for ( unsigned int i = 0; i < 4; ++i ) uint32_t* p = reinterpret_cast<uint32_t*>(in6.s6_addr);
p[i] &= mask_bits[i];
} for ( unsigned int i = 0; i < 4; ++i )
p[i] &= mask_bits[i];
void IPAddr::ReverseMask(int top_bits_to_chop) }
{
if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 ) bool IPAddr::ConvertString(const char* s, in6_addr* result) {
{ for ( auto p = s; *p; ++p )
reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop); if ( *p == ':' )
return; // IPv6
} return (inet_pton(AF_INET6, s, result->s6_addr) == 1);
uint32_t mask_bits[4] = {0, 0, 0, 0}; // IPv4
std::ldiv_t res = std::ldiv(top_bits_to_chop, 32); // Parse the address directly instead of using inet_pton since
// some platforms have more sensitive implementations than others
if ( res.quot < 4 ) // that can't e.g. handle leading zeroes.
mask_bits[res.quot] = htonl(bit_mask32(32 - res.rem)); int a[4];
int n = 0;
for ( unsigned int i = res.quot + 1; i < 4; ++i ) int match_count = sscanf(s, "%d.%d.%d.%d%n", a + 0, a + 1, a + 2, a + 3, &n);
mask_bits[i] = 0xffffffff;
if ( match_count != 4 )
uint32_t* p = reinterpret_cast<uint32_t*>(in6.s6_addr); return false;
for ( unsigned int i = 0; i < 4; ++i ) if ( s[n] != '\0' )
p[i] &= mask_bits[i]; return false;
}
for ( auto i = 0; i < 4; ++i )
bool IPAddr::ConvertString(const char* s, in6_addr* result) if ( a[i] < 0 || a[i] > 255 )
{ return false;
for ( auto p = s; *p; ++p )
if ( *p == ':' ) uint32_t addr = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3];
// IPv6 addr = htonl(addr);
return (inet_pton(AF_INET6, s, result->s6_addr) == 1); memcpy(result->s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix));
memcpy(&result->s6_addr[12], &addr, sizeof(uint32_t));
// IPv4 return true;
// Parse the address directly instead of using inet_pton since }
// some platforms have more sensitive implementations than others
// that can't e.g. handle leading zeroes. void IPAddr::Init(const char* s) {
int a[4]; if ( ! ConvertString(s, &in6) ) {
int n = 0; reporter->Error("Bad IP address: %s", s);
int match_count = sscanf(s, "%d.%d.%d.%d%n", a + 0, a + 1, a + 2, a + 3, &n); memset(in6.s6_addr, 0, sizeof(in6.s6_addr));
}
if ( match_count != 4 ) }
return false;
std::string IPAddr::AsString() const {
if ( s[n] != '\0' ) if ( GetFamily() == IPv4 ) {
return false; char s[INET_ADDRSTRLEN];
for ( auto i = 0; i < 4; ++i ) if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) )
if ( a[i] < 0 || a[i] > 255 ) return "<bad IPv4 address conversion";
return false; else
return s;
uint32_t addr = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]; }
addr = htonl(addr); else {
memcpy(result->s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); char s[INET6_ADDRSTRLEN];
memcpy(&result->s6_addr[12], &addr, sizeof(uint32_t));
return true; if ( ! zeek_inet_ntop(AF_INET6, in6.s6_addr, s, INET6_ADDRSTRLEN) )
} return "<bad IPv6 address conversion";
else
void IPAddr::Init(const char* s) return s;
{ }
if ( ! ConvertString(s, &in6) ) }
{
reporter->Error("Bad IP address: %s", s); std::string IPAddr::AsHexString() const {
memset(in6.s6_addr, 0, sizeof(in6.s6_addr)); char buf[33];
}
} if ( GetFamily() == IPv4 ) {
uint32_t* p = (uint32_t*)&in6.s6_addr[12];
std::string IPAddr::AsString() const snprintf(buf, sizeof(buf), "%08x", (uint32_t)ntohl(*p));
{ }
if ( GetFamily() == IPv4 ) else {
{ uint32_t* p = (uint32_t*)in6.s6_addr;
char s[INET_ADDRSTRLEN]; snprintf(buf, sizeof(buf), "%08x%08x%08x%08x", (uint32_t)ntohl(p[0]), (uint32_t)ntohl(p[1]),
(uint32_t)ntohl(p[2]), (uint32_t)ntohl(p[3]));
if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) ) }
return "<bad IPv4 address conversion";
else return buf;
return s; }
}
else std::string IPAddr::PtrName() const {
{ if ( GetFamily() == IPv4 ) {
char s[INET6_ADDRSTRLEN]; char buf[256];
uint32_t* p = (uint32_t*)&in6.s6_addr[12];
if ( ! zeek_inet_ntop(AF_INET6, in6.s6_addr, s, INET6_ADDRSTRLEN) ) uint32_t a = ntohl(*p);
return "<bad IPv6 address conversion"; uint32_t a3 = (a >> 24) & 0xff;
else uint32_t a2 = (a >> 16) & 0xff;
return s; uint32_t a1 = (a >> 8) & 0xff;
} uint32_t a0 = a & 0xff;
} snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3);
return buf;
std::string IPAddr::AsHexString() const }
{ else {
char buf[33]; static const char hex_digit[] = "0123456789abcdef";
std::string ptr_name("ip6.arpa");
if ( GetFamily() == IPv4 ) uint32_t* p = (uint32_t*)in6.s6_addr;
{
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; for ( unsigned int i = 0; i < 4; ++i ) {
snprintf(buf, sizeof(buf), "%08x", (uint32_t)ntohl(*p)); uint32_t a = ntohl(p[i]);
} for ( unsigned int j = 1; j <= 8; ++j ) {
else ptr_name.insert(0, 1, '.');
{ ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]);
uint32_t* p = (uint32_t*)in6.s6_addr; }
snprintf(buf, sizeof(buf), "%08x%08x%08x%08x", (uint32_t)ntohl(p[0]), (uint32_t)ntohl(p[1]), }
(uint32_t)ntohl(p[2]), (uint32_t)ntohl(p[3]));
} return ptr_name;
}
return buf; }
}
IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) {
std::string IPAddr::PtrName() const if ( length > 32 ) {
{ reporter->Error("Bad in4_addr IPPrefix length : %d", length);
if ( GetFamily() == IPv4 ) this->length = 0;
{ }
char buf[256];
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; prefix.Mask(this->length);
uint32_t a = ntohl(*p); }
uint32_t a3 = (a >> 24) & 0xff;
uint32_t a2 = (a >> 16) & 0xff; IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) {
uint32_t a1 = (a >> 8) & 0xff; if ( length > 128 ) {
uint32_t a0 = a & 0xff; reporter->Error("Bad in6_addr IPPrefix length : %d", length);
snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3); this->length = 0;
return buf; }
}
else prefix.Mask(this->length);
{ }
static const char hex_digit[] = "0123456789abcdef";
std::string ptr_name("ip6.arpa"); bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const {
uint32_t* p = (uint32_t*)in6.s6_addr; if ( GetFamily() == IPv4 && ! len_is_v6_relative ) {
if ( length > 32 )
for ( unsigned int i = 0; i < 4; ++i ) return false;
{ }
uint32_t a = ntohl(p[i]);
for ( unsigned int j = 1; j <= 8; ++j ) else {
{ if ( length > 128 )
ptr_name.insert(0, 1, '.'); return false;
ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]); }
}
} return true;
}
return ptr_name;
} IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr) {
} if ( prefix.CheckPrefixLength(length, len_is_v6_relative) ) {
if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative )
IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) this->length = length + 96;
{ else
if ( length > 32 ) this->length = length;
{ }
reporter->Error("Bad in4_addr IPPrefix length : %d", length); else {
this->length = 0; auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6";
} reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length);
this->length = 0;
prefix.Mask(this->length); }
}
prefix.Mask(this->length);
IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) }
{
if ( length > 128 ) std::string IPPrefix::AsString() const {
{ char l[16];
reporter->Error("Bad in6_addr IPPrefix length : %d", length);
this->length = 0; if ( prefix.GetFamily() == IPv4 )
} modp_uitoa10(length - 96, l);
else
prefix.Mask(this->length); modp_uitoa10(length, l);
}
return prefix.AsString() + "/" + l;
bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const }
{
if ( GetFamily() == IPv4 && ! len_is_v6_relative ) std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const {
{ struct {
if ( length > 32 ) in6_addr ip;
return false; uint32_t len;
} } key;
else key.ip = prefix.in6;
{ key.len = Length();
if ( length > 128 )
return false; return std::make_unique<detail::HashKey>(&key, sizeof(key));
} }
return true; bool IPPrefix::ConvertString(const char* text, IPPrefix* result) {
} std::string s(text);
size_t slash_loc = s.find('/');
IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr)
{ if ( slash_loc == std::string::npos )
if ( prefix.CheckPrefixLength(length, len_is_v6_relative) ) return false;
{
if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative ) auto ip_str = s.substr(0, slash_loc);
this->length = length + 96; auto len = atoi(s.substr(slash_loc + 1).data());
else
this->length = length; in6_addr tmp;
}
else if ( ! IPAddr::ConvertString(ip_str.data(), &tmp) )
{ return false;
auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6";
reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length); auto ip = IPAddr(tmp);
this->length = 0;
} if ( ! ip.CheckPrefixLength(len) )
return false;
prefix.Mask(this->length);
} *result = IPPrefix(ip, len);
return true;
std::string IPPrefix::AsString() const }
{
char l[16]; } // namespace zeek
if ( prefix.GetFamily() == IPv4 )
modp_uitoa10(length - 96, l);
else
modp_uitoa10(length, l);
return prefix.AsString() + "/" + l;
}
std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const
{
struct
{
in6_addr ip;
uint32_t len;
} key;
key.ip = prefix.in6;
key.len = Length();
return std::make_unique<detail::HashKey>(&key, sizeof(key));
}
bool IPPrefix::ConvertString(const char* text, IPPrefix* result)
{
std::string s(text);
size_t slash_loc = s.find('/');
if ( slash_loc == std::string::npos )
return false;
auto ip_str = s.substr(0, slash_loc);
auto len = atoi(s.substr(slash_loc + 1).data());
in6_addr tmp;
if ( ! IPAddr::ConvertString(ip_str.data(), &tmp) )
return false;
auto ip = IPAddr(tmp);
if ( ! ip.CheckPrefixLength(len) )
return false;
*result = IPPrefix(ip, len);
return true;
}
} // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -4,20 +4,18 @@
#include <cstring> #include <cstring>
namespace zeek::detail namespace zeek::detail {
{
void IntSet::Expand(unsigned int i) void IntSet::Expand(unsigned int i) {
{ unsigned int newsize = i / 8 + 1;
unsigned int newsize = i / 8 + 1; unsigned char* newset = new unsigned char[newsize];
unsigned char* newset = new unsigned char[newsize];
memset(newset, 0, newsize); memset(newset, 0, newsize);
memcpy(newset, set, size); memcpy(newset, set, size);
delete[] set; delete[] set;
size = newsize; size = newsize;
set = newset; set = newset;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -8,65 +8,51 @@
#include <cstring> #include <cstring>
namespace zeek::detail namespace zeek::detail {
{
class IntSet class IntSet {
{
public: public:
// n is a hint for the value of the largest integer. // n is a hint for the value of the largest integer.
explicit IntSet(unsigned int n = 1); explicit IntSet(unsigned int n = 1);
~IntSet(); ~IntSet();
void Insert(unsigned int i); void Insert(unsigned int i);
void Remove(unsigned int i); void Remove(unsigned int i);
bool Contains(unsigned int i) const; bool Contains(unsigned int i) const;
void Clear(); void Clear();
private: private:
void Expand(unsigned int i); void Expand(unsigned int i);
unsigned int size; unsigned int size;
unsigned char* set; unsigned char* set;
}; };
inline IntSet::IntSet(unsigned int n) inline IntSet::IntSet(unsigned int n) {
{ size = n / 8 + 1;
size = n / 8 + 1; set = new unsigned char[size];
set = new unsigned char[size]; memset(set, 0, size);
memset(set, 0, size); }
}
inline IntSet::~IntSet() inline IntSet::~IntSet() { delete[] set; }
{
delete[] set;
}
inline void IntSet::Insert(unsigned int i) inline void IntSet::Insert(unsigned int i) {
{ if ( i / 8 >= size )
if ( i / 8 >= size ) Expand(i);
Expand(i);
set[i / 8] |= (1 << (i % 8)); set[i / 8] |= (1 << (i % 8));
} }
inline void IntSet::Remove(unsigned int i) inline void IntSet::Remove(unsigned int i) {
{ if ( i / 8 >= size )
if ( i / 8 >= size ) Expand(i);
Expand(i); else
else set[i / 8] &= ~(1 << (i % 8));
set[i / 8] &= ~(1 << (i % 8)); }
}
inline bool IntSet::Contains(unsigned int i) const inline bool IntSet::Contains(unsigned int i) const { return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false; }
{
return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false;
}
inline void IntSet::Clear() inline void IntSet::Clear() { memset(set, 0, size); }
{
memset(set, 0, size);
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -8,24 +8,19 @@
#include "Obj.h" #include "Obj.h"
namespace zeek namespace zeek {
{
/** /**
* A tag class for the #IntrusivePtr constructor which means: adopt * A tag class for the #IntrusivePtr constructor which means: adopt
* the reference from the caller. * the reference from the caller.
*/ */
struct AdoptRef struct AdoptRef {};
{
};
/** /**
* A tag class for the #IntrusivePtr constructor which means: create a * A tag class for the #IntrusivePtr constructor which means: create a
* new reference to the object. * new reference to the object.
*/ */
struct NewRef struct NewRef {};
{
};
/** /**
* This has to be forward declared and known here in order for us to be able * This has to be forward declared and known here in order for us to be able
@ -55,131 +50,120 @@ class OpaqueVal;
* should use a smart pointer whenever possible to reduce boilerplate code and * should use a smart pointer whenever possible to reduce boilerplate code and
* increase robustness of the code (in particular w.r.t. exceptions). * increase robustness of the code (in particular w.r.t. exceptions).
*/ */
template <class T> class IntrusivePtr template<class T>
{ class IntrusivePtr {
public: public:
// -- member types // -- member types
using pointer = T*; using pointer = T*;
using const_pointer = const T*; using const_pointer = const T*;
using element_type = T; using element_type = T;
using reference = T&; using reference = T&;
using const_reference = const T&; using const_reference = const T&;
// -- constructors, destructors, and assignment operators // -- constructors, destructors, and assignment operators
constexpr IntrusivePtr() noexcept = default; constexpr IntrusivePtr() noexcept = default;
constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() {
{ // nop
// nop }
}
/** /**
* Constructs a new intrusive pointer for managing the lifetime of the object * Constructs a new intrusive pointer for managing the lifetime of the object
* pointed to by @c raw_ptr. * pointed to by @c raw_ptr.
* *
* This overload adopts the existing reference from the caller. * This overload adopts the existing reference from the caller.
* *
* @param raw_ptr Pointer to the shared object. * @param raw_ptr Pointer to the shared object.
*/ */
constexpr IntrusivePtr(AdoptRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) { } constexpr IntrusivePtr(AdoptRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) {}
/** /**
* Constructs a new intrusive pointer for managing the lifetime of the object * Constructs a new intrusive pointer for managing the lifetime of the object
* pointed to by @c raw_ptr. * pointed to by @c raw_ptr.
* *
* This overload adds a new reference. * This overload adds a new reference.
* *
* @param raw_ptr Pointer to the shared object. * @param raw_ptr Pointer to the shared object.
*/ */
IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) {
{ if ( ptr_ )
if ( ptr_ ) Ref(ptr_);
Ref(ptr_); }
}
IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) {
{ // nop
// nop }
}
IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) { } IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) {}
template <class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>> template<class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>>
IntrusivePtr(IntrusivePtr<U> other) noexcept : ptr_(other.release()) IntrusivePtr(IntrusivePtr<U> other) noexcept : ptr_(other.release()) {
{ // nop
// nop }
}
~IntrusivePtr() ~IntrusivePtr() {
{ if ( ptr_ ) {
if ( ptr_ ) // Specializing `OpaqueVal` as MSVC compiler does not detect it
{ // inheriting from `zeek::Obj` so we have to do that manually.
// Specializing `OpaqueVal` as MSVC compiler does not detect it if constexpr ( std::is_same_v<T, OpaqueVal> )
// inheriting from `zeek::Obj` so we have to do that manually. Unref(reinterpret_cast<zeek::Obj*>(ptr_));
if constexpr ( std::is_same_v<T, OpaqueVal> ) else
Unref(reinterpret_cast<zeek::Obj*>(ptr_)); Unref(ptr_);
else }
Unref(ptr_); }
}
}
void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); } void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); }
friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept {
{ using std::swap;
using std::swap; swap(a.ptr_, b.ptr_);
swap(a.ptr_, b.ptr_); }
}
/** /**
* Detaches an object from the automated lifetime management and sets this * Detaches an object from the automated lifetime management and sets this
* intrusive pointer to @c nullptr. * intrusive pointer to @c nullptr.
* @returns the raw pointer without modifying the reference count. * @returns the raw pointer without modifying the reference count.
*/ */
pointer release() noexcept { return std::exchange(ptr_, nullptr); } pointer release() noexcept { return std::exchange(ptr_, nullptr); }
IntrusivePtr& operator=(const IntrusivePtr& other) noexcept IntrusivePtr& operator=(const IntrusivePtr& other) noexcept {
{ IntrusivePtr tmp{other};
IntrusivePtr tmp{other}; swap(tmp);
swap(tmp); return *this;
return *this; }
}
IntrusivePtr& operator=(IntrusivePtr&& other) noexcept IntrusivePtr& operator=(IntrusivePtr&& other) noexcept {
{ swap(other);
swap(other); return *this;
return *this; }
}
IntrusivePtr& operator=(std::nullptr_t) noexcept IntrusivePtr& operator=(std::nullptr_t) noexcept {
{ if ( ptr_ ) {
if ( ptr_ ) Unref(ptr_);
{ ptr_ = nullptr;
Unref(ptr_); }
ptr_ = nullptr; return *this;
} }
return *this;
}
pointer get() const noexcept { return ptr_; } pointer get() const noexcept { return ptr_; }
pointer operator->() const noexcept { return ptr_; } pointer operator->() const noexcept { return ptr_; }
reference operator*() const noexcept { return *ptr_; } reference operator*() const noexcept { return *ptr_; }
bool operator!() const noexcept { return ! ptr_; } bool operator!() const noexcept { return ! ptr_; }
explicit operator bool() const noexcept { return ptr_ != nullptr; } explicit operator bool() const noexcept { return ptr_ != nullptr; }
private: private:
pointer ptr_ = nullptr; pointer ptr_ = nullptr;
}; };
/** /**
* Convenience function for creating a reference counted object and wrapping it * Convenience function for creating a reference counted object and wrapping it
@ -189,11 +173,11 @@ private:
* @note This function assumes that any @c T starts with a reference count of 1. * @note This function assumes that any @c T starts with a reference count of 1.
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T, class... Ts> IntrusivePtr<T> make_intrusive(Ts&&... args) template<class T, class... Ts>
{ IntrusivePtr<T> make_intrusive(Ts&&... args) {
// Assumes that objects start with a reference count of 1! // Assumes that objects start with a reference count of 1!
return {AdoptRef{}, new T(std::forward<Ts>(args)...)}; return {AdoptRef{}, new T(std::forward<Ts>(args)...)};
} }
/** /**
* Casts an @c IntrusivePtr object to another by way of static_cast on * Casts an @c IntrusivePtr object to another by way of static_cast on
@ -201,78 +185,78 @@ template <class T, class... Ts> IntrusivePtr<T> make_intrusive(Ts&&... args)
* @param p The pointer of type @c U to cast to another type, @c T. * @param p The pointer of type @c U to cast to another type, @c T.
* @return The pointer, as cast to type @c T. * @return The pointer, as cast to type @c T.
*/ */
template <class T, class U> IntrusivePtr<T> cast_intrusive(IntrusivePtr<U> p) noexcept template<class T, class U>
{ IntrusivePtr<T> cast_intrusive(IntrusivePtr<U> p) noexcept {
return {AdoptRef{}, static_cast<T*>(p.release())}; return {AdoptRef{}, static_cast<T*>(p.release())};
} }
// -- comparison to nullptr ---------------------------------------------------- // -- comparison to nullptr ----------------------------------------------------
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const zeek::IntrusivePtr<T>& x, std::nullptr_t) template<class T>
{ bool operator==(const zeek::IntrusivePtr<T>& x, std::nullptr_t) {
return ! x; return ! x;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(std::nullptr_t, const zeek::IntrusivePtr<T>& x) template<class T>
{ bool operator==(std::nullptr_t, const zeek::IntrusivePtr<T>& x) {
return ! x; return ! x;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const zeek::IntrusivePtr<T>& x, std::nullptr_t) template<class T>
{ bool operator!=(const zeek::IntrusivePtr<T>& x, std::nullptr_t) {
return static_cast<bool>(x); return static_cast<bool>(x);
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(std::nullptr_t, const zeek::IntrusivePtr<T>& x) template<class T>
{ bool operator!=(std::nullptr_t, const zeek::IntrusivePtr<T>& x) {
return static_cast<bool>(x); return static_cast<bool>(x);
} }
// -- comparison to raw pointer ------------------------------------------------ // -- comparison to raw pointer ------------------------------------------------
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const zeek::IntrusivePtr<T>& x, const T* y) template<class T>
{ bool operator==(const zeek::IntrusivePtr<T>& x, const T* y) {
return x.get() == y; return x.get() == y;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const T* x, const zeek::IntrusivePtr<T>& y) template<class T>
{ bool operator==(const T* x, const zeek::IntrusivePtr<T>& y) {
return x == y.get(); return x == y.get();
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const zeek::IntrusivePtr<T>& x, const T* y) template<class T>
{ bool operator!=(const zeek::IntrusivePtr<T>& x, const T* y) {
return x.get() != y; return x.get() != y;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y) template<class T>
{ bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y) {
return x != y.get(); return x != y.get();
} }
// -- comparison to intrusive pointer ------------------------------------------ // -- comparison to intrusive pointer ------------------------------------------
@ -282,35 +266,27 @@ template <class T> bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y)
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T, class U> template<class T, class U>
auto operator==(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) auto operator==(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) -> decltype(x.get() == y.get()) {
-> decltype(x.get() == y.get()) return x.get() == y.get();
{ }
return x.get() == y.get();
}
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T, class U> template<class T, class U>
auto operator!=(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) auto operator!=(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) -> decltype(x.get() != y.get()) {
-> decltype(x.get() != y.get()) return x.get() != y.get();
{ }
return x.get() != y.get();
}
} // namespace zeek } // namespace zeek
// -- hashing ------------------------------------------------ // -- hashing ------------------------------------------------
namespace std namespace std {
{ template<class T>
template <class T> struct hash<zeek::IntrusivePtr<T>> struct hash<zeek::IntrusivePtr<T>> {
{ // Hash of intrusive pointer is the same as hash of the raw pointer it holds.
// Hash of intrusive pointer is the same as hash of the raw pointer it holds. size_t operator()(const zeek::IntrusivePtr<T>& v) const noexcept { return std::hash<T*>{}(v.get()); }
size_t operator()(const zeek::IntrusivePtr<T>& v) const noexcept };
{ } // namespace std
return std::hash<T*>{}(v.get());
}
};
}

View file

@ -2,130 +2,125 @@
#include "zeek/3rdparty/doctest.h" #include "zeek/3rdparty/doctest.h"
TEST_CASE("list construction") TEST_CASE("list construction") {
{ zeek::List<int> list;
zeek::List<int> list; CHECK(list.empty());
CHECK(list.empty());
zeek::List<int> list2(10); zeek::List<int> list2(10);
CHECK(list2.empty()); CHECK(list2.empty());
CHECK(list2.max() == 10); CHECK(list2.max() == 10);
} }
TEST_CASE("list operation") TEST_CASE("list operation") {
{ zeek::List<int> list({1, 2, 3});
zeek::List<int> list({1, 2, 3}); CHECK(list.size() == 3);
CHECK(list.size() == 3); CHECK(list.max() == 3);
CHECK(list.max() == 3); CHECK(list[0] == 1);
CHECK(list[0] == 1); CHECK(list[1] == 2);
CHECK(list[1] == 2); CHECK(list[2] == 3);
CHECK(list[2] == 3);
// push_back forces a resize of the list here, which grows the list // push_back forces a resize of the list here, which grows the list
// by a growth factor. That makes the max elements equal to 6. // by a growth factor. That makes the max elements equal to 6.
list.push_back(4); list.push_back(4);
CHECK(list.size() == 4); CHECK(list.size() == 4);
CHECK(list.max() == 6); CHECK(list.max() == 6);
CHECK(list[3] == 4); CHECK(list[3] == 4);
CHECK(list.front() == 1); CHECK(list.front() == 1);
CHECK(list.back() == 4); CHECK(list.back() == 4);
list.pop_front(); list.pop_front();
CHECK(list.size() == 3); CHECK(list.size() == 3);
CHECK(list.front() == 2); CHECK(list.front() == 2);
list.pop_back(); list.pop_back();
CHECK(list.size() == 2); CHECK(list.size() == 2);
CHECK(list.back() == 3); CHECK(list.back() == 3);
list.push_back(4); list.push_back(4);
CHECK(list.is_member(2)); CHECK(list.is_member(2));
CHECK(list.member_pos(2) == 0); CHECK(list.member_pos(2) == 0);
list.remove(2); list.remove(2);
CHECK(list.size() == 2); CHECK(list.size() == 2);
CHECK(list[0] == 3); CHECK(list[0] == 3);
CHECK(list[1] == 4); CHECK(list[1] == 4);
// Squash the list down to the existing elements. // Squash the list down to the existing elements.
list.resize(); list.resize();
CHECK(list.size() == 2); CHECK(list.size() == 2);
CHECK(list.max() == 2); CHECK(list.max() == 2);
// Attempt replacing a known position. // Attempt replacing a known position.
int old = list.replace(0, 10); int old = list.replace(0, 10);
CHECK(list.size() == 2); CHECK(list.size() == 2);
CHECK(list.max() == 2); CHECK(list.max() == 2);
CHECK(old == 3); CHECK(old == 3);
CHECK(list[0] == 10); CHECK(list[0] == 10);
CHECK(list[1] == 4); CHECK(list[1] == 4);
// Attempt replacing an element off the end of the list, which // Attempt replacing an element off the end of the list, which
// causes a resize. // causes a resize.
old = list.replace(3, 5); old = list.replace(3, 5);
CHECK(list.size() == 4); CHECK(list.size() == 4);
CHECK(list.max() == 4); CHECK(list.max() == 4);
CHECK(old == 0); CHECK(old == 0);
CHECK(list[0] == 10); CHECK(list[0] == 10);
CHECK(list[1] == 4); CHECK(list[1] == 4);
CHECK(list[2] == 0); CHECK(list[2] == 0);
CHECK(list[3] == 5); CHECK(list[3] == 5);
// Attempt replacing an element with a negative index, which returns the // Attempt replacing an element with a negative index, which returns the
// default value for the list type. // default value for the list type.
old = list.replace(-1, 50); old = list.replace(-1, 50);
CHECK(list.size() == 4); CHECK(list.size() == 4);
CHECK(list.max() == 4); CHECK(list.max() == 4);
CHECK(old == 0); CHECK(old == 0);
list.clear(); list.clear();
CHECK(list.size() == 0); CHECK(list.size() == 0);
CHECK(list.max() == 0); CHECK(list.max() == 0);
} }
TEST_CASE("list iteration") TEST_CASE("list iteration") {
{ zeek::List<int> list({1, 2, 3, 4});
zeek::List<int> list({1, 2, 3, 4});
int index = 1; int index = 1;
for ( int v : list ) for ( int v : list )
CHECK(v == index++); CHECK(v == index++);
index = 1; index = 1;
for ( auto it = list.begin(); it != list.end(); index++, ++it ) for ( auto it = list.begin(); it != list.end(); index++, ++it )
CHECK(*it == index); CHECK(*it == index);
} }
TEST_CASE("plists") TEST_CASE("plists") {
{ zeek::PList<int> list;
zeek::PList<int> list; list.push_back(new int{1});
list.push_back(new int{1}); list.push_back(new int{2});
list.push_back(new int{2}); list.push_back(new int{3});
list.push_back(new int{3});
CHECK(*list[0] == 1); CHECK(*list[0] == 1);
int* new_val = new int(5); int* new_val = new int(5);
auto old = list.replace(-1, new_val); auto old = list.replace(-1, new_val);
delete new_val; delete new_val;
CHECK(old == nullptr); CHECK(old == nullptr);
for ( auto v : list ) for ( auto v : list )
delete v; delete v;
list.clear(); list.clear();
} }
TEST_CASE("unordered list operation") TEST_CASE("unordered list operation") {
{ zeek::List<int, zeek::ListOrder::UNORDERED> list({1, 2, 3, 4});
zeek::List<int, zeek::ListOrder::UNORDERED> list({1, 2, 3, 4}); CHECK(list.size() == 4);
CHECK(list.size() == 4);
// An unordered list doesn't maintain the ordering of the elements when // An unordered list doesn't maintain the ordering of the elements when
// one is removed. It just swaps the last element into the hole. // one is removed. It just swaps the last element into the hole.
list.remove(2); list.remove(2);
CHECK(list.size() == 3); CHECK(list.size() == 3);
CHECK(list[0] == 1); CHECK(list[0] == 1);
CHECK(list[1] == 4); CHECK(list[1] == 4);
CHECK(list[2] == 3); CHECK(list[2] == 3);
} }

View file

@ -27,313 +27,293 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
enum class ListOrder : int enum class ListOrder : int { ORDERED, UNORDERED };
{
ORDERED,
UNORDERED
};
template <typename T, ListOrder Order = ListOrder::ORDERED> class List template<typename T, ListOrder Order = ListOrder::ORDERED>
{ class List {
public: public:
constexpr static int DEFAULT_LIST_SIZE = 10; constexpr static int DEFAULT_LIST_SIZE = 10;
constexpr static int LIST_GROWTH_FACTOR = 2; constexpr static int LIST_GROWTH_FACTOR = 2;
~List() { free(entries); } ~List() { free(entries); }
explicit List(int size = 0) explicit List(int size = 0) {
{ num_entries = 0;
num_entries = 0;
if ( size <= 0 ) if ( size <= 0 ) {
{ max_entries = 0;
max_entries = 0; entries = nullptr;
entries = nullptr; return;
return; }
}
max_entries = size; max_entries = size;
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); entries = (T*)util::safe_malloc(max_entries * sizeof(T));
} }
List(const List& b) List(const List& b) {
{ max_entries = b.max_entries;
max_entries = b.max_entries; num_entries = b.num_entries;
num_entries = b.num_entries;
if ( max_entries ) if ( max_entries )
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); entries = (T*)util::safe_malloc(max_entries * sizeof(T));
else else
entries = nullptr; entries = nullptr;
for ( int i = 0; i < num_entries; ++i ) for ( int i = 0; i < num_entries; ++i )
entries[i] = b.entries[i]; entries[i] = b.entries[i];
} }
List(List&& b) List(List&& b) {
{ entries = b.entries;
entries = b.entries; num_entries = b.num_entries;
num_entries = b.num_entries; max_entries = b.max_entries;
max_entries = b.max_entries;
b.entries = nullptr; b.entries = nullptr;
b.num_entries = b.max_entries = 0; b.num_entries = b.max_entries = 0;
} }
List(const T* arr, int n) List(const T* arr, int n) {
{ num_entries = max_entries = n;
num_entries = max_entries = n; entries = (T*)util::safe_malloc(max_entries * sizeof(T));
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); memcpy(entries, arr, n * sizeof(T));
memcpy(entries, arr, n * sizeof(T)); }
}
List(std::initializer_list<T> il) : List(il.begin(), il.size()) { } List(std::initializer_list<T> il) : List(il.begin(), il.size()) {}
List& operator=(const List& b) List& operator=(const List& b) {
{ if ( this == &b )
if ( this == &b ) return *this;
return *this;
free(entries); free(entries);
max_entries = b.max_entries; max_entries = b.max_entries;
num_entries = b.num_entries; num_entries = b.num_entries;
if ( max_entries ) if ( max_entries )
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); entries = (T*)util::safe_malloc(max_entries * sizeof(T));
else else
entries = nullptr; entries = nullptr;
for ( int i = 0; i < num_entries; ++i ) for ( int i = 0; i < num_entries; ++i )
entries[i] = b.entries[i]; entries[i] = b.entries[i];
return *this; return *this;
} }
List& operator=(List&& b) List& operator=(List&& b) {
{ if ( this == &b )
if ( this == &b ) return *this;
return *this;
free(entries); free(entries);
entries = b.entries; entries = b.entries;
num_entries = b.num_entries; num_entries = b.num_entries;
max_entries = b.max_entries; max_entries = b.max_entries;
b.entries = nullptr; b.entries = nullptr;
b.num_entries = b.max_entries = 0; b.num_entries = b.max_entries = 0;
return *this; return *this;
} }
// Return nth ent of list (do not remove). // Return nth ent of list (do not remove).
T& operator[](int i) const { return entries[i]; } T& operator[](int i) const { return entries[i]; }
void clear() // remove all entries void clear() // remove all entries
{ {
free(entries); free(entries);
entries = nullptr; entries = nullptr;
num_entries = max_entries = 0; num_entries = max_entries = 0;
} }
bool empty() const noexcept { return num_entries == 0; } bool empty() const noexcept { return num_entries == 0; }
size_t size() const noexcept { return num_entries; } size_t size() const noexcept { return num_entries; }
int length() const { return num_entries; } int length() const { return num_entries; }
int max() const { return max_entries; } int max() const { return max_entries; }
int resize(int new_size = 0) // 0 => size to fit current number of entries int resize(int new_size = 0) // 0 => size to fit current number of entries
{ {
if ( new_size < num_entries ) if ( new_size < num_entries )
new_size = num_entries; // do not lose any entries new_size = num_entries; // do not lose any entries
if ( new_size != max_entries ) if ( new_size != max_entries ) {
{ entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size);
entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size); if ( entries )
if ( entries ) max_entries = new_size;
max_entries = new_size; else
else max_entries = 0;
max_entries = 0; }
}
return max_entries; return max_entries;
} }
void push_front(const T& a) void push_front(const T& a) {
{ if ( num_entries == max_entries )
if ( num_entries == max_entries ) resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
for ( int i = num_entries; i > 0; --i ) for ( int i = num_entries; i > 0; --i )
entries[i] = entries[i - 1]; // move all pointers up one entries[i] = entries[i - 1]; // move all pointers up one
++num_entries; ++num_entries;
entries[0] = a; entries[0] = a;
} }
void push_back(const T& a) void push_back(const T& a) {
{ if ( num_entries == max_entries )
if ( num_entries == max_entries ) resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
entries[num_entries++] = a; entries[num_entries++] = a;
} }
void pop_front() { remove_nth(0); } void pop_front() { remove_nth(0); }
void pop_back() { remove_nth(num_entries - 1); } void pop_back() { remove_nth(num_entries - 1); }
T& front() { return entries[0]; } T& front() { return entries[0]; }
T& back() { return entries[num_entries - 1]; } T& back() { return entries[num_entries - 1]; }
// The append method is maintained for historical/compatibility reasons. // The append method is maintained for historical/compatibility reasons.
// (It's commonly used in the event generation API) // (It's commonly used in the event generation API)
void append(const T& a) // add to end of list void append(const T& a) // add to end of list
{ {
push_back(a); push_back(a);
} }
bool remove(const T& a) // delete entry from list bool remove(const T& a) // delete entry from list
{ {
int pos = member_pos(a); int pos = member_pos(a);
if ( pos != -1 ) if ( pos != -1 ) {
{ remove_nth(pos);
remove_nth(pos); return true;
return true; }
}
return false; return false;
} }
T remove_nth(int n) // delete nth entry from list T remove_nth(int n) // delete nth entry from list
{ {
assert(n >= 0 && n < num_entries); assert(n >= 0 && n < num_entries);
T old_ent = entries[n]; T old_ent = entries[n];
// For data where we don't care about ordering, we don't care about keeping // For data where we don't care about ordering, we don't care about keeping
// the list in the same order when removing an element. Just swap the last // the list in the same order when removing an element. Just swap the last
// element with the element being removed. // element with the element being removed.
if constexpr ( Order == ListOrder::ORDERED ) if constexpr ( Order == ListOrder::ORDERED ) {
{ --num_entries;
--num_entries;
for ( ; n < num_entries; ++n ) for ( ; n < num_entries; ++n )
entries[n] = entries[n + 1]; entries[n] = entries[n + 1];
} }
else else {
{ entries[n] = entries[num_entries - 1];
entries[n] = entries[num_entries - 1]; --num_entries;
--num_entries; }
}
return old_ent; return old_ent;
} }
// Return 0 if ent is not in the list, ent otherwise. // Return 0 if ent is not in the list, ent otherwise.
bool is_member(const T& a) const bool is_member(const T& a) const {
{ int pos = member_pos(a);
int pos = member_pos(a); return pos != -1;
return pos != -1; }
}
// Returns -1 if ent is not in the list, otherwise its position. // Returns -1 if ent is not in the list, otherwise its position.
int member_pos(const T& e) const int member_pos(const T& e) const {
{ int i;
int i; for ( i = 0; i < length() && e != entries[i]; ++i )
for ( i = 0; i < length() && e != entries[i]; ++i ) ;
;
return (i == length()) ? -1 : i; return (i == length()) ? -1 : i;
} }
T replace(int ent_index, const T& new_ent) // replace entry #i with a new value T replace(int ent_index, const T& new_ent) // replace entry #i with a new value
{ {
if ( ent_index < 0 ) if ( ent_index < 0 )
return T{}; return T{};
T old_ent{}; T old_ent{};
if ( ent_index > num_entries - 1 ) if ( ent_index > num_entries - 1 ) { // replacement beyond the end of the list
{ // replacement beyond the end of the list resize(ent_index + 1);
resize(ent_index + 1);
for ( int i = num_entries; i < max_entries; ++i ) for ( int i = num_entries; i < max_entries; ++i )
entries[i] = T{}; entries[i] = T{};
num_entries = max_entries; num_entries = max_entries;
} }
else else
old_ent = entries[ent_index]; old_ent = entries[ent_index];
entries[ent_index] = new_ent; entries[ent_index] = new_ent;
return old_ent; return old_ent;
} }
// Type traits needed for some of the std algorithms to work // Type traits needed for some of the std algorithms to work
using value_type = T; using value_type = T;
using pointer = T*; using pointer = T*;
using const_pointer = const T*; using const_pointer = const T*;
// Iterator support // Iterator support
using iterator = pointer; using iterator = pointer;
using const_iterator = const_pointer; using const_iterator = const_pointer;
using reverse_iterator = std::reverse_iterator<iterator>; using reverse_iterator = std::reverse_iterator<iterator>;
using const_reverse_iterator = std::reverse_iterator<const_iterator>; using const_reverse_iterator = std::reverse_iterator<const_iterator>;
iterator begin() { return entries; } iterator begin() { return entries; }
iterator end() { return entries + num_entries; } iterator end() { return entries + num_entries; }
const_iterator begin() const { return entries; } const_iterator begin() const { return entries; }
const_iterator end() const { return entries + num_entries; } const_iterator end() const { return entries + num_entries; }
const_iterator cbegin() const { return entries; } const_iterator cbegin() const { return entries; }
const_iterator cend() const { return entries + num_entries; } const_iterator cend() const { return entries + num_entries; }
reverse_iterator rbegin() { return reverse_iterator{end()}; } reverse_iterator rbegin() { return reverse_iterator{end()}; }
reverse_iterator rend() { return reverse_iterator{begin()}; } reverse_iterator rend() { return reverse_iterator{begin()}; }
const_reverse_iterator rbegin() const { return const_reverse_iterator{end()}; } const_reverse_iterator rbegin() const { return const_reverse_iterator{end()}; }
const_reverse_iterator rend() const { return const_reverse_iterator{begin()}; } const_reverse_iterator rend() const { return const_reverse_iterator{begin()}; }
const_reverse_iterator crbegin() const { return rbegin(); } const_reverse_iterator crbegin() const { return rbegin(); }
const_reverse_iterator crend() const { return rend(); } const_reverse_iterator crend() const { return rend(); }
protected: protected:
// This could essentially be an std::vector if we wanted. Some // This could essentially be an std::vector if we wanted. Some
// reasons to maybe not refactor to use std::vector ? // reasons to maybe not refactor to use std::vector ?
// //
// - Harder to use a custom growth factor. Also, the growth // - Harder to use a custom growth factor. Also, the growth
// factor would be implementation-specific, taking some control over // factor would be implementation-specific, taking some control over
// performance out of our hands. // performance out of our hands.
// //
// - It won't ever take advantage of realloc's occasional ability to // - It won't ever take advantage of realloc's occasional ability to
// grow in-place. // grow in-place.
// //
// - Combine above point this with lack of control of growth // - Combine above point this with lack of control of growth
// factor means the common choice of 2x growth factor causes // factor means the common choice of 2x growth factor causes
// a growth pattern that crawls forward in memory with no possible // a growth pattern that crawls forward in memory with no possible
// re-use of previous chunks (the new capacity is always larger than // re-use of previous chunks (the new capacity is always larger than
// all previously allocated chunks combined). This point and // all previously allocated chunks combined). This point and
// whether 2x is empirically an issue still seems debated (at least // whether 2x is empirically an issue still seems debated (at least
// GCC seems to stand by 2x as empirically better). // GCC seems to stand by 2x as empirically better).
// //
// - Sketchy shrinking behavior: standard says that requests to // - Sketchy shrinking behavior: standard says that requests to
// shrink are non-binding (it's expected implementations heed, but // shrink are non-binding (it's expected implementations heed, but
// still not great to have no guarantee). Also, it would not take // still not great to have no guarantee). Also, it would not take
// advantage of realloc's ability to contract in-place, it would // advantage of realloc's ability to contract in-place, it would
// allocate-and-copy. // allocate-and-copy.
T* entries; T* entries;
int max_entries; int max_entries;
int num_entries; int num_entries;
}; };
// Specialization of the List class to store pointers of a type. // Specialization of the List class to store pointers of a type.
template <typename T, ListOrder Order = ListOrder::ORDERED> using PList = List<T*, Order>; template<typename T, ListOrder Order = ListOrder::ORDERED>
using PList = List<T*, Order>;
// Popular type of list: list of strings. // Popular type of list: list of strings.
using name_list = PList<char>; using name_list = PList<char>;
} // namespace zeek } // namespace zeek
// Macro to visit each list element in turn. // Macro to visit each list element in turn.
#define loop_over_list(list, iterator) \ #define loop_over_list(list, iterator) \
int iterator; \ int iterator; \
for ( iterator = 0; iterator < (list).length(); ++iterator ) for ( iterator = 0; iterator < (list).length(); ++iterator )

View file

@ -10,360 +10,314 @@
#include "zeek/EquivClass.h" #include "zeek/EquivClass.h"
#include "zeek/IntSet.h" #include "zeek/IntSet.h"
namespace zeek::detail namespace zeek::detail {
{
static int nfa_state_id = 0; static int nfa_state_id = 0;
NFA_State::NFA_State(int arg_sym, EquivClass* ec) NFA_State::NFA_State(int arg_sym, EquivClass* ec) {
{ sym = arg_sym;
sym = arg_sym; ccl = nullptr;
ccl = nullptr; accept = NO_ACCEPT;
accept = NO_ACCEPT; first_trans_is_back_ref = false;
first_trans_is_back_ref = false; mark = nullptr;
mark = nullptr; epsclosure = nullptr;
epsclosure = nullptr; id = ++nfa_state_id;
id = ++nfa_state_id;
// Fix up equivalence classes based on this transition. Note that any
// Fix up equivalence classes based on this transition. Note that any // character which has its own transition gets its own equivalence
// character which has its own transition gets its own equivalence // class. Thus only characters which are only in character classes
// class. Thus only characters which are only in character classes // have a chance at being in the same equivalence class. E.g. "a|b"
// have a chance at being in the same equivalence class. E.g. "a|b" // puts 'a' and 'b' into two different equivalence classes. "[ab]"
// puts 'a' and 'b' into two different equivalence classes. "[ab]" // puts them in the same equivalence class (barring other differences
// puts them in the same equivalence class (barring other differences // elsewhere in the input).
// elsewhere in the input).
if ( ec && sym != SYM_EPSILON /* no associated symbol */ )
if ( ec && sym != SYM_EPSILON /* no associated symbol */ ) ec->UniqueChar(sym);
ec->UniqueChar(sym); }
}
NFA_State::NFA_State(CCL* arg_ccl) {
NFA_State::NFA_State(CCL* arg_ccl) sym = SYM_CCL;
{ ccl = arg_ccl;
sym = SYM_CCL; accept = NO_ACCEPT;
ccl = arg_ccl; first_trans_is_back_ref = false;
accept = NO_ACCEPT; mark = nullptr;
first_trans_is_back_ref = false; id = ++nfa_state_id;
mark = nullptr; epsclosure = nullptr;
id = ++nfa_state_id; }
epsclosure = nullptr;
} NFA_State::~NFA_State() {
for ( int i = 0; i < xtions.length(); ++i )
NFA_State::~NFA_State() if ( i > 0 || ! first_trans_is_back_ref )
{ Unref(xtions[i]);
for ( int i = 0; i < xtions.length(); ++i )
if ( i > 0 || ! first_trans_is_back_ref ) delete epsclosure;
Unref(xtions[i]); }
delete epsclosure; void NFA_State::AddXtionsTo(NFA_state_list* ns) {
} for ( int i = 0; i < xtions.length(); ++i )
ns->push_back(xtions[i]);
void NFA_State::AddXtionsTo(NFA_state_list* ns) }
{
for ( int i = 0; i < xtions.length(); ++i ) NFA_State* NFA_State::DeepCopy() {
ns->push_back(xtions[i]); if ( mark ) {
} Ref(mark);
return mark;
NFA_State* NFA_State::DeepCopy() }
{
if ( mark ) NFA_State* copy = ccl ? new NFA_State(ccl) : new NFA_State(sym, nullptr);
{ SetMark(copy);
Ref(mark);
return mark; for ( int i = 0; i < xtions.length(); ++i )
} copy->AddXtion(xtions[i]->DeepCopy());
NFA_State* copy = ccl ? new NFA_State(ccl) : new NFA_State(sym, nullptr); return copy;
SetMark(copy); }
for ( int i = 0; i < xtions.length(); ++i ) void NFA_State::ClearMarks() {
copy->AddXtion(xtions[i]->DeepCopy()); if ( mark ) {
SetMark(nullptr);
return copy; for ( int i = 0; i < xtions.length(); ++i )
} xtions[i]->ClearMarks();
}
void NFA_State::ClearMarks() }
{
if ( mark ) NFA_state_list* NFA_State::EpsilonClosure() {
{ if ( epsclosure )
SetMark(nullptr); return epsclosure;
for ( int i = 0; i < xtions.length(); ++i )
xtions[i]->ClearMarks(); epsclosure = new NFA_state_list;
}
} NFA_state_list states;
states.push_back(this);
NFA_state_list* NFA_State::EpsilonClosure() SetMark(this);
{
if ( epsclosure ) int i;
return epsclosure; for ( i = 0; i < states.length(); ++i ) {
NFA_State* ns = states[i];
epsclosure = new NFA_state_list; if ( ns->TransSym() == SYM_EPSILON ) {
NFA_state_list* x = ns->Transitions();
NFA_state_list states; for ( int j = 0; j < x->length(); ++j ) {
states.push_back(this); NFA_State* nxt = (*x)[j];
SetMark(this); if ( ! nxt->Mark() ) {
states.push_back(nxt);
int i; nxt->SetMark(nxt);
for ( i = 0; i < states.length(); ++i ) }
{ }
NFA_State* ns = states[i];
if ( ns->TransSym() == SYM_EPSILON ) if ( ns->Accept() != NO_ACCEPT )
{ epsclosure->push_back(ns);
NFA_state_list* x = ns->Transitions(); }
for ( int j = 0; j < x->length(); ++j )
{ else
NFA_State* nxt = (*x)[j]; // Non-epsilon transition - keep it.
if ( ! nxt->Mark() ) epsclosure->push_back(ns);
{ }
states.push_back(nxt);
nxt->SetMark(nxt); // Clear out markers.
} for ( i = 0; i < states.length(); ++i )
} states[i]->SetMark(nullptr);
if ( ns->Accept() != NO_ACCEPT ) // Make it fit.
epsclosure->push_back(ns); epsclosure->resize(0);
}
return epsclosure;
else }
// Non-epsilon transition - keep it.
epsclosure->push_back(ns); void NFA_State::Describe(ODesc* d) const { d->Add("NFA state"); }
}
void NFA_State::Dump(FILE* f) {
// Clear out markers. if ( mark )
for ( i = 0; i < states.length(); ++i ) return;
states[i]->SetMark(nullptr);
fprintf(f, "NFA state %d, sym = %d, accept = %d:\n", id, sym, accept);
// Make it fit.
epsclosure->resize(0); for ( int i = 0; i < xtions.length(); ++i )
fprintf(f, "\ttransition to %d\n", xtions[i]->ID());
return epsclosure;
} SetMark(this);
for ( int i = 0; i < xtions.length(); ++i )
void NFA_State::Describe(ODesc* d) const xtions[i]->Dump(f);
{ }
d->Add("NFA state");
} NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) {
first_state = first;
void NFA_State::Dump(FILE* f) final_state = final ? final : first;
{ eol = bol = 0;
if ( mark ) }
return;
NFA_Machine::~NFA_Machine() { Unref(first_state); }
fprintf(f, "NFA state %d, sym = %d, accept = %d:\n", id, sym, accept);
void NFA_Machine::InsertEpsilon() {
for ( int i = 0; i < xtions.length(); ++i ) NFA_State* eps = new EpsilonState();
fprintf(f, "\ttransition to %d\n", xtions[i]->ID()); eps->AddXtion(first_state);
first_state = eps;
SetMark(this); }
for ( int i = 0; i < xtions.length(); ++i )
xtions[i]->Dump(f); void NFA_Machine::AppendEpsilon() { AppendState(new EpsilonState()); }
}
void NFA_Machine::AddAccept(int accept_val) {
NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) // Hang the accepting number off an epsilon state. If it is associated
{ // with a state that has a non-epsilon out-transition, then the state
first_state = first; // will accept BEFORE it makes that transition, i.e., one character
final_state = final ? final : first; // too soon.
eol = bol = 0;
} if ( final_state->TransSym() != SYM_EPSILON )
AppendState(new EpsilonState());
NFA_Machine::~NFA_Machine()
{ final_state->SetAccept(accept_val);
Unref(first_state); }
}
void NFA_Machine::LinkCopies(int n) {
void NFA_Machine::InsertEpsilon() if ( n <= 0 )
{ return;
NFA_State* eps = new EpsilonState();
eps->AddXtion(first_state); // Make all the copies before doing any appending, otherwise
first_state = eps; // subsequent DuplicateMachine()'s will include the extra
} // copies!
NFA_Machine** copies = new NFA_Machine*[n];
void NFA_Machine::AppendEpsilon()
{ int i;
AppendState(new EpsilonState()); for ( i = 0; i < n; ++i )
} copies[i] = DuplicateMachine();
void NFA_Machine::AddAccept(int accept_val) for ( i = 0; i < n; ++i )
{ AppendMachine(copies[i]);
// Hang the accepting number off an epsilon state. If it is associated
// with a state that has a non-epsilon out-transition, then the state delete[] copies;
// will accept BEFORE it makes that transition, i.e., one character }
// too soon.
NFA_Machine* NFA_Machine::DuplicateMachine() {
if ( final_state->TransSym() != SYM_EPSILON ) NFA_State* new_first_state = first_state->DeepCopy();
AppendState(new EpsilonState()); NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark());
first_state->ClearMarks();
final_state->SetAccept(accept_val);
} return new_m;
}
void NFA_Machine::LinkCopies(int n)
{ void NFA_Machine::AppendState(NFA_State* s) {
if ( n <= 0 ) final_state->AddXtion(s);
return; final_state = s;
}
// Make all the copies before doing any appending, otherwise
// subsequent DuplicateMachine()'s will include the extra void NFA_Machine::AppendMachine(NFA_Machine* m) {
// copies! AppendEpsilon();
NFA_Machine** copies = new NFA_Machine*[n]; final_state->AddXtion(m->FirstState());
final_state = m->FinalState();
int i;
for ( i = 0; i < n; ++i ) Ref(m->FirstState()); // so states stay around after the following
copies[i] = DuplicateMachine(); Unref(m);
}
for ( i = 0; i < n; ++i )
AppendMachine(copies[i]); void NFA_Machine::MakeOptional() {
InsertEpsilon();
delete[] copies; AppendEpsilon();
} first_state->AddXtion(final_state);
Ref(final_state);
NFA_Machine* NFA_Machine::DuplicateMachine() }
{
NFA_State* new_first_state = first_state->DeepCopy(); void NFA_Machine::MakePositiveClosure() {
NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark()); AppendEpsilon();
first_state->ClearMarks(); final_state->AddXtion(first_state);
return new_m; // Don't Ref the state the final epsilon points to, otherwise we'll
} // have reference cycles that lead to leaks.
final_state->SetFirstTransIsBackRef();
void NFA_Machine::AppendState(NFA_State* s) }
{
final_state->AddXtion(s); void NFA_Machine::MakeRepl(int lower, int upper) {
final_state = s; NFA_Machine* dup = nullptr;
} if ( upper > lower || upper == NO_UPPER_BOUND )
dup = DuplicateMachine();
void NFA_Machine::AppendMachine(NFA_Machine* m)
{ LinkCopies(lower - 1);
AppendEpsilon();
final_state->AddXtion(m->FirstState()); if ( upper == NO_UPPER_BOUND ) {
final_state = m->FinalState(); dup->MakeClosure();
AppendMachine(dup);
Ref(m->FirstState()); // so states stay around after the following return;
Unref(m); }
}
while ( upper > lower ) {
void NFA_Machine::MakeOptional() NFA_Machine* dup2;
{ if ( --upper == lower )
InsertEpsilon(); // Don't need "dup" for any further copies
AppendEpsilon(); dup2 = dup;
first_state->AddXtion(final_state); else
Ref(final_state); dup2 = dup->DuplicateMachine();
}
dup2->MakeOptional();
void NFA_Machine::MakePositiveClosure() AppendMachine(dup2);
{ }
AppendEpsilon(); }
final_state->AddXtion(first_state);
void NFA_Machine::Describe(ODesc* d) const { d->Add("NFA machine"); }
// Don't Ref the state the final epsilon points to, otherwise we'll
// have reference cycles that lead to leaks. void NFA_Machine::Dump(FILE* f) {
final_state->SetFirstTransIsBackRef(); first_state->Dump(f);
} first_state->ClearMarks();
}
void NFA_Machine::MakeRepl(int lower, int upper)
{ NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) {
NFA_Machine* dup = nullptr; if ( ! m1 )
if ( upper > lower || upper == NO_UPPER_BOUND ) return m2;
dup = DuplicateMachine(); if ( ! m2 )
return m1;
LinkCopies(lower - 1);
NFA_State* first = new EpsilonState();
if ( upper == NO_UPPER_BOUND ) NFA_State* last = new EpsilonState();
{
dup->MakeClosure(); first->AddXtion(m1->FirstState());
AppendMachine(dup); first->AddXtion(m2->FirstState());
return;
} m1->AppendState(last);
m2->AppendState(last);
while ( upper > lower ) Ref(last);
{
NFA_Machine* dup2; // Keep these around.
if ( --upper == lower ) Ref(m1->FirstState());
// Don't need "dup" for any further copies Ref(m2->FirstState());
dup2 = dup;
else Unref(m1);
dup2 = dup->DuplicateMachine(); Unref(m2);
dup2->MakeOptional(); return new NFA_Machine(first, last);
AppendMachine(dup2); }
}
} NFA_state_list* epsilon_closure(NFA_state_list* states) {
// We just keep one of this as it may get quite large.
void NFA_Machine::Describe(ODesc* d) const static IntSet closuremap;
{ closuremap.Clear();
d->Add("NFA machine");
} NFA_state_list* closure = new NFA_state_list;
void NFA_Machine::Dump(FILE* f) for ( int i = 0; i < states->length(); ++i ) {
{ NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure();
first_state->Dump(f);
first_state->ClearMarks(); for ( int j = 0; j < stateclosure->length(); ++j ) {
} NFA_State* ns = (*stateclosure)[j];
if ( ! closuremap.Contains(ns->ID()) ) {
NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) closuremap.Insert(ns->ID());
{ closure->push_back(ns);
if ( ! m1 ) }
return m2; }
if ( ! m2 ) }
return m1;
// Sort all of the closures in the list by ID
NFA_State* first = new EpsilonState(); std::sort(closure->begin(), closure->end(), NFA_state_cmp_neg);
NFA_State* last = new EpsilonState();
// Make it fit.
first->AddXtion(m1->FirstState()); closure->resize(0);
first->AddXtion(m2->FirstState());
delete states;
m1->AppendState(last);
m2->AppendState(last); return closure;
Ref(last); }
// Keep these around. bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2) {
Ref(m1->FirstState()); if ( v1->ID() < v2->ID() )
Ref(m2->FirstState()); return true;
else
Unref(m1); return false;
Unref(m2); }
return new NFA_Machine(first, last); } // namespace zeek::detail
}
NFA_state_list* epsilon_closure(NFA_state_list* states)
{
// We just keep one of this as it may get quite large.
static IntSet closuremap;
closuremap.Clear();
NFA_state_list* closure = new NFA_state_list;
for ( int i = 0; i < states->length(); ++i )
{
NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure();
for ( int j = 0; j < stateclosure->length(); ++j )
{
NFA_State* ns = (*stateclosure)[j];
if ( ! closuremap.Contains(ns->ID()) )
{
closuremap.Insert(ns->ID());
closure->push_back(ns);
}
}
}
// Sort all of the closures in the list by ID
std::sort(closure->begin(), closure->end(), NFA_state_cmp_neg);
// Make it fit.
closure->resize(0);
delete states;
return closure;
}
bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2)
{
if ( v1->ID() < v2->ID() )
return true;
else
return false;
}
} // namespace zeek::detail

146
src/NFA.h
View file

@ -16,13 +16,11 @@
#define SYM_EPSILON 259 #define SYM_EPSILON 259
#define SYM_CCL 260 #define SYM_CCL 260
namespace zeek namespace zeek {
{
class Func; class Func;
namespace detail namespace detail {
{
class CCL; class CCL;
class EquivClass; class EquivClass;
@ -30,104 +28,100 @@ class EquivClass;
class NFA_State; class NFA_State;
using NFA_state_list = PList<NFA_State>; using NFA_state_list = PList<NFA_State>;
class NFA_State : public Obj class NFA_State : public Obj {
{
public: public:
NFA_State(int sym, EquivClass* ec); NFA_State(int sym, EquivClass* ec);
explicit NFA_State(CCL* ccl); explicit NFA_State(CCL* ccl);
~NFA_State() override; ~NFA_State() override;
void AddXtion(NFA_State* next_state) { xtions.push_back(next_state); } void AddXtion(NFA_State* next_state) { xtions.push_back(next_state); }
NFA_state_list* Transitions() { return &xtions; } NFA_state_list* Transitions() { return &xtions; }
void AddXtionsTo(NFA_state_list* ns); void AddXtionsTo(NFA_state_list* ns);
void SetAccept(int accept_val) { accept = accept_val; } void SetAccept(int accept_val) { accept = accept_val; }
int Accept() const { return accept; } int Accept() const { return accept; }
// Returns a deep copy of this NFA state and everything it points // Returns a deep copy of this NFA state and everything it points
// to. Upon return, each state's marker is set to point to its // to. Upon return, each state's marker is set to point to its
// copy. // copy.
NFA_State* DeepCopy(); NFA_State* DeepCopy();
void SetMark(NFA_State* m) { mark = m; } void SetMark(NFA_State* m) { mark = m; }
NFA_State* Mark() const { return mark; } NFA_State* Mark() const { return mark; }
void ClearMarks(); void ClearMarks();
void SetFirstTransIsBackRef() { first_trans_is_back_ref = true; } void SetFirstTransIsBackRef() { first_trans_is_back_ref = true; }
int TransSym() const { return sym; } int TransSym() const { return sym; }
CCL* TransCCL() const { return ccl; } CCL* TransCCL() const { return ccl; }
int ID() const { return id; } int ID() const { return id; }
NFA_state_list* EpsilonClosure(); NFA_state_list* EpsilonClosure();
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
void Dump(FILE* f); void Dump(FILE* f);
protected: protected:
int sym; // if SYM_CCL, then use ccl int sym; // if SYM_CCL, then use ccl
int id; // number that uniquely identifies this state int id; // number that uniquely identifies this state
CCL* ccl; // if nil, then use sym CCL* ccl; // if nil, then use sym
int accept; int accept;
// Whether the first transition points backwards. Used // Whether the first transition points backwards. Used
// to avoid reference-counting loops. // to avoid reference-counting loops.
bool first_trans_is_back_ref; bool first_trans_is_back_ref;
NFA_state_list xtions; NFA_state_list xtions;
NFA_state_list* epsclosure; NFA_state_list* epsclosure;
NFA_State* mark; NFA_State* mark;
}; };
class EpsilonState : public NFA_State class EpsilonState : public NFA_State {
{
public: public:
EpsilonState() : NFA_State(SYM_EPSILON, nullptr) { } EpsilonState() : NFA_State(SYM_EPSILON, nullptr) {}
}; };
class NFA_Machine : public Obj class NFA_Machine : public Obj {
{
public: public:
explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr); explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr);
~NFA_Machine() override; ~NFA_Machine() override;
NFA_State* FirstState() const { return first_state; } NFA_State* FirstState() const { return first_state; }
void SetFinalState(NFA_State* final) { final_state = final; } void SetFinalState(NFA_State* final) { final_state = final; }
NFA_State* FinalState() const { return final_state; } NFA_State* FinalState() const { return final_state; }
void AddAccept(int accept_val); void AddAccept(int accept_val);
void MakeClosure() void MakeClosure() {
{ MakePositiveClosure();
MakePositiveClosure(); MakeOptional();
MakeOptional(); }
} void MakeOptional();
void MakeOptional(); void MakePositiveClosure();
void MakePositiveClosure();
// re{lower,upper}; upper can be NO_UPPER_BOUND = infinity. // re{lower,upper}; upper can be NO_UPPER_BOUND = infinity.
void MakeRepl(int lower, int upper); void MakeRepl(int lower, int upper);
void MarkBOL() { bol = 1; } void MarkBOL() { bol = 1; }
void MarkEOL() { eol = 1; } void MarkEOL() { eol = 1; }
NFA_Machine* DuplicateMachine(); NFA_Machine* DuplicateMachine();
void LinkCopies(int n); void LinkCopies(int n);
void InsertEpsilon(); void InsertEpsilon();
void AppendEpsilon(); void AppendEpsilon();
void AppendState(NFA_State* new_state); void AppendState(NFA_State* new_state);
void AppendMachine(NFA_Machine* new_mach); void AppendMachine(NFA_Machine* new_mach);
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
void Dump(FILE* f); void Dump(FILE* f);
protected: protected:
NFA_State* first_state; NFA_State* first_state;
NFA_State* final_state; NFA_State* final_state;
int bol, eol; int bol, eol;
}; };
extern NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2); extern NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2);
@ -141,5 +135,5 @@ extern NFA_state_list* epsilon_closure(NFA_state_list* states);
// For sorting NFA states based on their ID fields (decreasing) // For sorting NFA states based on their ID fields (decreasing)
extern bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2); extern bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2);
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -105,8 +105,7 @@ zeek::StringVal* cmd_line_bpf_filter;
zeek::StringVal* global_hash_seed; zeek::StringVal* global_hash_seed;
namespace zeek::detail namespace zeek::detail {
{
int watchdog_interval; int watchdog_interval;
@ -194,29 +193,26 @@ zeek_uint_t bits_per_uid;
zeek_uint_t tunnel_max_changes_per_connection; zeek_uint_t tunnel_max_changes_per_connection;
} // namespace zeek::detail. The namespace has be closed here before we include the netvar_def } // namespace zeek::detail
// files. // files.
// Because of how the BIF include files are built with namespaces already in them, // Because of how the BIF include files are built with namespaces already in them,
// these files need to be included separately before the namespace is opened below. // these files need to be included separately before the namespace is opened below.
static void bif_init_event_handlers() static void bif_init_event_handlers() {
{
#include "event.bif.netvar_init" #include "event.bif.netvar_init"
} }
static void bif_init_net_var() static void bif_init_net_var() {
{
#include "const.bif.netvar_init" #include "const.bif.netvar_init"
#include "packet_analysis.bif.netvar_init" #include "packet_analysis.bif.netvar_init"
#include "reporter.bif.netvar_init" #include "reporter.bif.netvar_init"
#include "supervisor.bif.netvar_init" #include "supervisor.bif.netvar_init"
} }
static void init_bif_types() static void init_bif_types() {
{
#include "types.bif.netvar_init" #include "types.bif.netvar_init"
} }
#include "const.bif.netvar_def" #include "const.bif.netvar_def"
#include "event.bif.netvar_def" #include "event.bif.netvar_def"
@ -226,127 +222,116 @@ static void init_bif_types()
#include "types.bif.netvar_def" #include "types.bif.netvar_def"
// Re-open the namespace now that the bif headers are all included. // Re-open the namespace now that the bif headers are all included.
namespace zeek::detail namespace zeek::detail {
{
void init_event_handlers() void init_event_handlers() { bif_init_event_handlers(); }
{
bif_init_event_handlers();
}
void init_general_global_var() void init_general_global_var() {
{ table_expire_interval = id::find_val("table_expire_interval")->AsInterval();
table_expire_interval = id::find_val("table_expire_interval")->AsInterval(); table_expire_delay = id::find_val("table_expire_delay")->AsInterval();
table_expire_delay = id::find_val("table_expire_delay")->AsInterval(); table_incremental_step = id::find_val("table_incremental_step")->AsCount();
table_incremental_step = id::find_val("table_incremental_step")->AsCount(); packet_filter_default = id::find_val("packet_filter_default")->AsBool();
packet_filter_default = id::find_val("packet_filter_default")->AsBool(); sig_max_group_size = id::find_val("sig_max_group_size")->AsCount();
sig_max_group_size = id::find_val("sig_max_group_size")->AsCount(); check_for_unused_event_handlers = id::find_val("check_for_unused_event_handlers")->AsBool();
check_for_unused_event_handlers = id::find_val("check_for_unused_event_handlers")->AsBool(); record_all_packets = id::find_val("record_all_packets")->AsBool();
record_all_packets = id::find_val("record_all_packets")->AsBool(); bits_per_uid = id::find_val("bits_per_uid")->AsCount();
bits_per_uid = id::find_val("bits_per_uid")->AsCount(); }
}
void init_builtin_types() void init_builtin_types() {
{ init_bif_types();
init_bif_types(); id::detail::init_types();
id::detail::init_types(); }
}
void init_net_var() void init_net_var() {
{ bif_init_net_var();
bif_init_net_var();
ignore_checksums = id::find_val("ignore_checksums")->AsBool(); ignore_checksums = id::find_val("ignore_checksums")->AsBool();
partial_connection_ok = id::find_val("partial_connection_ok")->AsBool(); partial_connection_ok = id::find_val("partial_connection_ok")->AsBool();
tcp_SYN_ack_ok = id::find_val("tcp_SYN_ack_ok")->AsBool(); tcp_SYN_ack_ok = id::find_val("tcp_SYN_ack_ok")->AsBool();
tcp_match_undelivered = id::find_val("tcp_match_undelivered")->AsBool(); tcp_match_undelivered = id::find_val("tcp_match_undelivered")->AsBool();
frag_timeout = id::find_val("frag_timeout")->AsInterval(); frag_timeout = id::find_val("frag_timeout")->AsInterval();
tcp_SYN_timeout = id::find_val("tcp_SYN_timeout")->AsInterval(); tcp_SYN_timeout = id::find_val("tcp_SYN_timeout")->AsInterval();
tcp_session_timer = id::find_val("tcp_session_timer")->AsInterval(); tcp_session_timer = id::find_val("tcp_session_timer")->AsInterval();
tcp_connection_linger = id::find_val("tcp_connection_linger")->AsInterval(); tcp_connection_linger = id::find_val("tcp_connection_linger")->AsInterval();
tcp_attempt_delay = id::find_val("tcp_attempt_delay")->AsInterval(); tcp_attempt_delay = id::find_val("tcp_attempt_delay")->AsInterval();
tcp_close_delay = id::find_val("tcp_close_delay")->AsInterval(); tcp_close_delay = id::find_val("tcp_close_delay")->AsInterval();
tcp_reset_delay = id::find_val("tcp_reset_delay")->AsInterval(); tcp_reset_delay = id::find_val("tcp_reset_delay")->AsInterval();
tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval(); tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval();
tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount(); tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount();
tcp_max_above_hole_without_any_acks = tcp_max_above_hole_without_any_acks = id::find_val("tcp_max_above_hole_without_any_acks")->AsCount();
id::find_val("tcp_max_above_hole_without_any_acks")->AsCount(); tcp_excessive_data_without_further_acks = id::find_val("tcp_excessive_data_without_further_acks")->AsCount();
tcp_excessive_data_without_further_acks = tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount();
id::find_val("tcp_excessive_data_without_further_acks")->AsCount();
tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount();
non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval(); non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval();
tcp_inactivity_timeout = id::find_val("tcp_inactivity_timeout")->AsInterval(); tcp_inactivity_timeout = id::find_val("tcp_inactivity_timeout")->AsInterval();
udp_inactivity_timeout = id::find_val("udp_inactivity_timeout")->AsInterval(); udp_inactivity_timeout = id::find_val("udp_inactivity_timeout")->AsInterval();
icmp_inactivity_timeout = id::find_val("icmp_inactivity_timeout")->AsInterval(); icmp_inactivity_timeout = id::find_val("icmp_inactivity_timeout")->AsInterval();
tcp_storm_thresh = id::find_val("tcp_storm_thresh")->AsCount(); tcp_storm_thresh = id::find_val("tcp_storm_thresh")->AsCount();
tcp_storm_interarrival_thresh = id::find_val("tcp_storm_interarrival_thresh")->AsInterval(); tcp_storm_interarrival_thresh = id::find_val("tcp_storm_interarrival_thresh")->AsInterval();
tcp_content_deliver_all_orig = bool(id::find_val("tcp_content_deliver_all_orig")->AsBool()); tcp_content_deliver_all_orig = bool(id::find_val("tcp_content_deliver_all_orig")->AsBool());
tcp_content_deliver_all_resp = bool(id::find_val("tcp_content_deliver_all_resp")->AsBool()); tcp_content_deliver_all_resp = bool(id::find_val("tcp_content_deliver_all_resp")->AsBool());
udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool()); udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool());
udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool()); udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool());
udp_content_delivery_ports_use_resp = bool( udp_content_delivery_ports_use_resp = bool(id::find_val("udp_content_delivery_ports_use_resp")->AsBool());
id::find_val("udp_content_delivery_ports_use_resp")->AsBool());
dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval(); dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval();
rpc_timeout = id::find_val("rpc_timeout")->AsInterval(); rpc_timeout = id::find_val("rpc_timeout")->AsInterval();
watchdog_interval = int(id::find_val("watchdog_interval")->AsInterval()); watchdog_interval = int(id::find_val("watchdog_interval")->AsInterval());
max_timer_expires = id::find_val("max_timer_expires")->AsCount(); max_timer_expires = id::find_val("max_timer_expires")->AsCount();
mime_segment_length = id::find_val("mime_segment_length")->AsCount(); mime_segment_length = id::find_val("mime_segment_length")->AsCount();
mime_segment_overlap_length = id::find_val("mime_segment_overlap_length")->AsCount(); mime_segment_overlap_length = id::find_val("mime_segment_overlap_length")->AsCount();
http_entity_data_delivery_size = id::find_val("http_entity_data_delivery_size")->AsCount(); http_entity_data_delivery_size = id::find_val("http_entity_data_delivery_size")->AsCount();
truncate_http_URI = id::find_val("truncate_http_URI")->AsInt(); truncate_http_URI = id::find_val("truncate_http_URI")->AsInt();
dns_skip_all_auth = id::find_val("dns_skip_all_auth")->AsBool(); dns_skip_all_auth = id::find_val("dns_skip_all_auth")->AsBool();
dns_skip_all_addl = id::find_val("dns_skip_all_addl")->AsBool(); dns_skip_all_addl = id::find_val("dns_skip_all_addl")->AsBool();
dns_max_queries = id::find_val("dns_max_queries")->AsCount(); dns_max_queries = id::find_val("dns_max_queries")->AsCount();
orig_addr_anonymization = 0; orig_addr_anonymization = 0;
if ( const auto& id = id::find("orig_addr_anonymization") ) if ( const auto& id = id::find("orig_addr_anonymization") )
if ( const auto& v = id->GetVal() ) if ( const auto& v = id->GetVal() )
orig_addr_anonymization = v->AsInt(); orig_addr_anonymization = v->AsInt();
resp_addr_anonymization = 0; resp_addr_anonymization = 0;
if ( const auto& id = id::find("resp_addr_anonymization") ) if ( const auto& id = id::find("resp_addr_anonymization") )
if ( const auto& v = id->GetVal() ) if ( const auto& v = id->GetVal() )
resp_addr_anonymization = v->AsInt(); resp_addr_anonymization = v->AsInt();
other_addr_anonymization = 0; other_addr_anonymization = 0;
if ( const auto& id = id::find("other_addr_anonymization") ) if ( const auto& id = id::find("other_addr_anonymization") )
if ( const auto& v = id->GetVal() ) if ( const auto& v = id->GetVal() )
other_addr_anonymization = v->AsInt(); other_addr_anonymization = v->AsInt();
connection_status_update_interval = 0.0; connection_status_update_interval = 0.0;
if ( const auto& id = id::find("connection_status_update_interval") ) if ( const auto& id = id::find("connection_status_update_interval") )
if ( const auto& v = id->GetVal() ) if ( const auto& v = id->GetVal() )
connection_status_update_interval = v->AsInterval(); connection_status_update_interval = v->AsInterval();
expensive_profiling_multiple = id::find_val("expensive_profiling_multiple")->AsCount(); expensive_profiling_multiple = id::find_val("expensive_profiling_multiple")->AsCount();
profiling_interval = id::find_val("profiling_interval")->AsInterval(); profiling_interval = id::find_val("profiling_interval")->AsInterval();
segment_profiling = id::find_val("segment_profiling")->AsBool(); segment_profiling = id::find_val("segment_profiling")->AsBool();
pkt_profile_mode = id::find_val("pkt_profile_mode")->InternalInt(); pkt_profile_mode = id::find_val("pkt_profile_mode")->InternalInt();
pkt_profile_freq = id::find_val("pkt_profile_freq")->AsDouble(); pkt_profile_freq = id::find_val("pkt_profile_freq")->AsDouble();
load_sample_freq = id::find_val("load_sample_freq")->AsCount(); load_sample_freq = id::find_val("load_sample_freq")->AsCount();
dpd_reassemble_first_packets = id::find_val("dpd_reassemble_first_packets")->AsBool(); dpd_reassemble_first_packets = id::find_val("dpd_reassemble_first_packets")->AsBool();
dpd_buffer_size = id::find_val("dpd_buffer_size")->AsCount(); dpd_buffer_size = id::find_val("dpd_buffer_size")->AsCount();
dpd_max_packets = id::find_val("dpd_max_packets")->AsCount(); dpd_max_packets = id::find_val("dpd_max_packets")->AsCount();
dpd_match_only_beginning = id::find_val("dpd_match_only_beginning")->AsBool(); dpd_match_only_beginning = id::find_val("dpd_match_only_beginning")->AsBool();
dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool(); dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool();
dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool(); dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool();
tunnel_max_changes_per_connection = tunnel_max_changes_per_connection = id::find_val("Tunnel::max_changes_per_connection")->AsCount();
id::find_val("Tunnel::max_changes_per_connection")->AsCount(); }
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,8 +6,7 @@
#include "zeek/Stats.h" #include "zeek/Stats.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
extern int watchdog_interval; extern int watchdog_interval;
@ -103,7 +102,7 @@ extern void init_event_handlers();
extern void init_net_var(); extern void init_net_var();
extern void init_builtin_types(); extern void init_builtin_types();
} // namespace zeek::detail } // namespace zeek::detail
#include "const.bif.netvar_h" #include "const.bif.netvar_h"
#include "event.bif.netvar_h" #include "event.bif.netvar_h"

View file

@ -8,84 +8,68 @@
zeek::notifier::detail::Registry zeek::notifier::detail::registry; zeek::notifier::detail::Registry zeek::notifier::detail::registry;
namespace zeek::notifier::detail namespace zeek::notifier::detail {
{
Receiver::Receiver() Receiver::Receiver() { DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this); }
{
DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this);
}
Receiver::~Receiver() Receiver::~Receiver() { DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this); }
{
DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this);
}
Registry::~Registry() Registry::~Registry() {
{ while ( registrations.begin() != registrations.end() )
while ( registrations.begin() != registrations.end() ) Unregister(registrations.begin()->first);
Unregister(registrations.begin()->first); }
}
void Registry::Register(Modifiable* m, Receiver* r) void Registry::Register(Modifiable* m, Receiver* r) {
{ DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r);
DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r);
registrations.insert({m, r}); registrations.insert({m, r});
++m->num_receivers; ++m->num_receivers;
} }
void Registry::Unregister(Modifiable* m, Receiver* r) void Registry::Unregister(Modifiable* m, Receiver* r) {
{ DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r);
DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
for ( auto i = x.first; i != x.second; i++ ) for ( auto i = x.first; i != x.second; i++ ) {
{ if ( i->second == r ) {
if ( i->second == r ) --i->first->num_receivers;
{ registrations.erase(i);
--i->first->num_receivers; break;
registrations.erase(i); }
break; }
} }
}
}
void Registry::Unregister(Modifiable* m) void Registry::Unregister(Modifiable* m) {
{ DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m);
DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
for ( auto i = x.first; i != x.second; i++ ) for ( auto i = x.first; i != x.second; i++ )
--i->first->num_receivers; --i->first->num_receivers;
registrations.erase(x.first, x.second); registrations.erase(x.first, x.second);
} }
void Registry::Modified(Modifiable* m) void Registry::Modified(Modifiable* m) {
{ DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m);
DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
for ( auto i = x.first; i != x.second; i++ ) for ( auto i = x.first; i != x.second; i++ )
i->second->Modified(m); i->second->Modified(m);
} }
void Registry::Terminate() void Registry::Terminate() {
{ std::set<Receiver*> receivers;
std::set<Receiver*> receivers;
for ( auto& r : registrations ) for ( auto& r : registrations )
receivers.emplace(r.second); receivers.emplace(r.second);
for ( auto& r : receivers ) for ( auto& r : receivers )
r->Terminate(); r->Terminate();
} }
Modifiable::~Modifiable() Modifiable::~Modifiable() {
{ if ( num_receivers )
if ( num_receivers ) registry.Unregister(this);
registry.Unregister(this); }
}
} // namespace zeek::notifier::detail } // namespace zeek::notifier::detail

View file

@ -10,86 +10,83 @@
#include <cstdint> #include <cstdint>
#include <unordered_map> #include <unordered_map>
namespace zeek::notifier::detail namespace zeek::notifier::detail {
{
class Modifiable; class Modifiable;
/** Interface class for receivers of notifications. */ /** Interface class for receivers of notifications. */
class Receiver class Receiver {
{
public: public:
Receiver(); Receiver();
virtual ~Receiver(); virtual ~Receiver();
/** /**
* Callback executed when a register object has been modified. * Callback executed when a register object has been modified.
* *
* @param m object that was modified * @param m object that was modified
*/ */
virtual void Modified(Modifiable* m) = 0; virtual void Modified(Modifiable* m) = 0;
/** /**
* Callback executed when notification registry is terminating and * Callback executed when notification registry is terminating and
* no further modifications can possibly occur. * no further modifications can possibly occur.
*/ */
virtual void Terminate() { } virtual void Terminate() {}
}; };
/** Singleton class tracking all notification requests globally. */ /** Singleton class tracking all notification requests globally. */
class Registry class Registry {
{
public: public:
~Registry(); ~Registry();
/** /**
* Registers a receiver to be informed when a modifiable object has * Registers a receiver to be informed when a modifiable object has
* changed. * changed.
* *
* @param m object to track. Does not take ownership, but the object * @param m object to track. Does not take ownership, but the object
* will automatically unregister itself on destruction. * will automatically unregister itself on destruction.
* *
* @param r receiver to notify on changes. Does not take ownership, * @param r receiver to notify on changes. Does not take ownership,
* the receiver must remain valid as long as the registration stays * the receiver must remain valid as long as the registration stays
* in place. * in place.
*/ */
void Register(Modifiable* m, Receiver* r); void Register(Modifiable* m, Receiver* r);
/** /**
* Cancels a receiver's request to be informed about an object's * Cancels a receiver's request to be informed about an object's
* modification. The arguments to the method must match what was * modification. The arguments to the method must match what was
* originally registered. * originally registered.
* *
* @param m object to no longer track. * @param m object to no longer track.
* *
* @param r receiver to no longer notify. * @param r receiver to no longer notify.
*/ */
void Unregister(Modifiable* m, Receiver* Receiver); void Unregister(Modifiable* m, Receiver* Receiver);
/** /**
* Cancels any active receiver requests to be informed about a * Cancels any active receiver requests to be informed about a
* particular object's modifications. * particular object's modifications.
* *
* @param m object to no longer track. * @param m object to no longer track.
*/ */
void Unregister(Modifiable* m); void Unregister(Modifiable* m);
/** /**
* Notifies all receivers that no further modifications will occur * Notifies all receivers that no further modifications will occur
* as the registry is shutting down. * as the registry is shutting down.
*/ */
void Terminate(); void Terminate();
private: private:
friend class Modifiable; friend class Modifiable;
// Inform all registered receivers of a modification to an object. // Inform all registered receivers of a modification to an object.
// Will be called from the object itself. // Will be called from the object itself.
void Modified(Modifiable* m); void Modified(Modifiable* m);
using ModifiableMap = std::unordered_multimap<Modifiable*, Receiver*>; using ModifiableMap = std::unordered_multimap<Modifiable*, Receiver*>;
ModifiableMap registrations; ModifiableMap registrations;
}; };
/** /**
* Singleton object tracking all global notification requests. * Singleton object tracking all global notification requests.
@ -100,26 +97,24 @@ extern Registry registry;
* Base class for objects that can trigger notifications to receivers when * Base class for objects that can trigger notifications to receivers when
* modified. * modified.
*/ */
class Modifiable class Modifiable {
{
public: public:
/** /**
* Calling this method signals to all registered receivers that the * Calling this method signals to all registered receivers that the
* object has been modified. * object has been modified.
*/ */
void Modified() void Modified() {
{ if ( num_receivers )
if ( num_receivers ) registry.Modified(this);
registry.Modified(this); }
}
protected: protected:
friend class Registry; friend class Registry;
virtual ~Modifiable(); virtual ~Modifiable();
// Number of currently registered receivers. // Number of currently registered receivers.
uint64_t num_receivers = 0; uint64_t num_receivers = 0;
}; };
} // namespace zeek::notifier::detail } // namespace zeek::notifier::detail

View file

@ -11,208 +11,180 @@
#include "zeek/Func.h" #include "zeek/Func.h"
#include "zeek/plugin/Manager.h" #include "zeek/plugin/Manager.h"
namespace zeek namespace zeek {
{ namespace detail {
namespace detail
{
Location start_location("<start uninitialized>", 0, 0, 0, 0); Location start_location("<start uninitialized>", 0, 0, 0, 0);
Location end_location("<end uninitialized>", 0, 0, 0, 0); Location end_location("<end uninitialized>", 0, 0, 0, 0);
void Location::Describe(ODesc* d) const void Location::Describe(ODesc* d) const {
{ if ( filename ) {
if ( filename ) d->Add(filename);
{
d->Add(filename);
if ( first_line == 0 ) if ( first_line == 0 )
return; return;
d->AddSP(","); d->AddSP(",");
} }
if ( last_line != first_line ) if ( last_line != first_line ) {
{ d->Add("lines ");
d->Add("lines "); d->Add(first_line);
d->Add(first_line); d->Add("-");
d->Add("-"); d->Add(last_line);
d->Add(last_line); }
} else {
else d->Add("line ");
{ d->Add(first_line);
d->Add("line "); }
d->Add(first_line); }
}
}
bool Location::operator==(const Location& l) const bool Location::operator==(const Location& l) const {
{ if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) )
if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) ) return first_line == l.first_line && last_line == l.last_line;
return first_line == l.first_line && last_line == l.last_line; else
else return false;
return false; }
}
} // namespace detail } // namespace detail
int Obj::suppress_errors = 0; int Obj::suppress_errors = 0;
Obj::~Obj() Obj::~Obj() {
{ if ( notify_plugins )
if ( notify_plugins ) PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this));
PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this));
delete location; delete location;
} }
void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const {
const detail::Location* expr_location) const ODesc d;
{ DoMsg(&d, msg, obj2, pinpoint_only, expr_location);
ODesc d; reporter->Warning("%s", d.Description());
DoMsg(&d, msg, obj2, pinpoint_only, expr_location); reporter->PopLocation();
reporter->Warning("%s", d.Description()); }
reporter->PopLocation();
}
void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const {
const detail::Location* expr_location) const if ( suppress_errors )
{ return;
if ( suppress_errors )
return;
ODesc d; ODesc d;
DoMsg(&d, msg, obj2, pinpoint_only, expr_location); DoMsg(&d, msg, obj2, pinpoint_only, expr_location);
reporter->Error("%s", d.Description()); reporter->Error("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::BadTag(const char* msg, const char* t1, const char* t2) const void Obj::BadTag(const char* msg, const char* t1, const char* t2) const {
{ char out[512];
char out[512];
if ( t2 ) if ( t2 )
snprintf(out, sizeof(out), "%s (%s/%s)", msg, t1, t2); snprintf(out, sizeof(out), "%s (%s/%s)", msg, t1, t2);
else if ( t1 ) else if ( t1 )
snprintf(out, sizeof(out), "%s (%s)", msg, t1); snprintf(out, sizeof(out), "%s (%s)", msg, t1);
else else
snprintf(out, sizeof(out), "%s", msg); snprintf(out, sizeof(out), "%s", msg);
ODesc d; ODesc d;
DoMsg(&d, out); DoMsg(&d, out);
reporter->FatalErrorWithCore("%s", d.Description()); reporter->FatalErrorWithCore("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::Internal(const char* msg) const void Obj::Internal(const char* msg) const {
{ ODesc d;
ODesc d; DoMsg(&d, msg);
DoMsg(&d, msg); auto rcs = render_call_stack();
auto rcs = render_call_stack();
if ( rcs.empty() ) if ( rcs.empty() )
reporter->InternalError("%s", d.Description()); reporter->InternalError("%s", d.Description());
else else
reporter->InternalError("%s, call stack: %s", d.Description(), rcs.data()); reporter->InternalError("%s, call stack: %s", d.Description(), rcs.data());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::InternalWarning(const char* msg) const void Obj::InternalWarning(const char* msg) const {
{ ODesc d;
ODesc d; DoMsg(&d, msg);
DoMsg(&d, msg); reporter->InternalWarning("%s", d.Description());
reporter->InternalWarning("%s", d.Description()); reporter->PopLocation();
reporter->PopLocation(); }
}
void Obj::AddLocation(ODesc* d) const void Obj::AddLocation(ODesc* d) const {
{ if ( ! location ) {
if ( ! location ) d->Add("<no location>");
{ return;
d->Add("<no location>"); }
return;
}
location->Describe(d); location->Describe(d);
} }
bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) {
{ if ( ! start || ! end )
if ( ! start || ! end ) return false;
return false;
if ( end->filename && ! util::streq(start->filename, end->filename) ) if ( end->filename && ! util::streq(start->filename, end->filename) )
return false; return false;
if ( location && (start == &detail::no_location || end == &detail::no_location) ) if ( location && (start == &detail::no_location || end == &detail::no_location) )
// We already have a better location, so don't use this one. // We already have a better location, so don't use this one.
return true; return true;
delete location; delete location;
location = new detail::Location(start->filename, start->first_line, end->last_line, location =
start->first_column, end->last_column); new detail::Location(start->filename, start->first_line, end->last_line, start->first_column, end->last_column);
return true; return true;
} }
void Obj::UpdateLocationEndInfo(const detail::Location& end) void Obj::UpdateLocationEndInfo(const detail::Location& end) {
{ if ( ! location )
if ( ! location ) SetLocationInfo(&end, &end);
SetLocationInfo(&end, &end);
location->last_line = end.last_line; location->last_line = end.last_line;
location->last_column = end.last_column; location->last_column = end.last_column;
} }
void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only, void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only,
const detail::Location* expr_location) const const detail::Location* expr_location) const {
{ d->SetShort();
d->SetShort();
d->Add(s1); d->Add(s1);
PinPoint(d, obj2, pinpoint_only); PinPoint(d, obj2, pinpoint_only);
const detail::Location* loc2 = nullptr; const detail::Location* loc2 = nullptr;
if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && *obj2->GetLocationInfo() != *GetLocationInfo() )
*obj2->GetLocationInfo() != *GetLocationInfo() ) loc2 = obj2->GetLocationInfo();
loc2 = obj2->GetLocationInfo(); else if ( expr_location )
else if ( expr_location ) loc2 = expr_location;
loc2 = expr_location;
reporter->PushLocation(GetLocationInfo(), loc2); reporter->PushLocation(GetLocationInfo(), loc2);
} }
void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const {
{ d->Add(" (");
d->Add(" ("); Describe(d);
Describe(d); if ( obj2 && ! pinpoint_only ) {
if ( obj2 && ! pinpoint_only ) d->Add(" and ");
{ obj2->Describe(d);
d->Add(" and "); }
obj2->Describe(d);
}
d->Add(")"); d->Add(")");
} }
void Obj::Print() const void Obj::Print() const {
{ static File fstderr(stderr);
static File fstderr(stderr); ODesc d(DESC_READABLE, &fstderr);
ODesc d(DESC_READABLE, &fstderr); Describe(&d);
Describe(&d); d.Add("\n");
d.Add("\n"); }
}
void bad_ref(int type) void bad_ref(int type) {
{ reporter->InternalError("bad reference count [%d]", type);
reporter->InternalError("bad reference count [%d]", type); abort();
abort(); }
}
void obj_delete_func(void* v) void obj_delete_func(void* v) { Unref((Obj*)v); }
{
Unref((Obj*)v);
}
} // namespace zeek } // namespace zeek

240
src/Obj.h
View file

@ -6,34 +6,28 @@
#include <climits> #include <climits>
namespace zeek namespace zeek {
{
class ODesc; class ODesc;
namespace detail namespace detail {
{
class Location final class Location final {
{
public: public:
constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept
: filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), : filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), last_column(col_l) {}
last_column(col_l)
{
}
Location() = default; Location() = default;
void Describe(ODesc* d) const; void Describe(ODesc* d) const;
bool operator==(const Location& l) const; bool operator==(const Location& l) const;
bool operator!=(const Location& l) const { return ! (*this == l); } bool operator!=(const Location& l) const { return ! (*this == l); }
const char* filename = nullptr; const char* filename = nullptr;
int first_line = 0, last_line = 0; int first_line = 0, last_line = 0;
int first_column = 0, last_column = 0; int first_column = 0, last_column = 0;
}; };
#define YYLTYPE zeek::detail::yyltype #define YYLTYPE zeek::detail::yyltype
using yyltype = Location; using yyltype = Location;
@ -48,154 +42,138 @@ extern Location start_location;
extern Location end_location; extern Location end_location;
// Used by parser to set the above. // Used by parser to set the above.
inline void set_location(const Location loc) inline void set_location(const Location loc) { start_location = end_location = loc; }
{
start_location = end_location = loc;
}
inline void set_location(const Location start, const Location end) inline void set_location(const Location start, const Location end) {
{ start_location = start;
start_location = start; end_location = end;
end_location = end; }
}
} // namespace detail } // namespace detail
class Obj class Obj {
{
public: public:
Obj() Obj() {
{ // A bit of a hack. We'd like to associate location
// A bit of a hack. We'd like to associate location // information with every object created when parsing,
// information with every object created when parsing, // since for them, the location is generally well-defined.
// since for them, the location is generally well-defined. // We could maintain a separate flag that tells us whether
// We could maintain a separate flag that tells us whether // we're inside a parse, but the parser also sets the
// we're inside a parse, but the parser also sets the // location to no_location when it's done, so it makes
// location to no_location when it's done, so it makes // sense to just check for that. *However*, start_location
// sense to just check for that. *However*, start_location // and end_location are maintained as their own objects
// and end_location are maintained as their own objects // rather than pointers or references, so we can't directly
// rather than pointers or references, so we can't directly // check them for equality with no_location. So instead
// check them for equality with no_location. So instead // we check for whether start_location has a line number
// we check for whether start_location has a line number // of 0, which should only happen if it's been assigned
// of 0, which should only happen if it's been assigned // to no_location (or hasn't been initialized at all).
// to no_location (or hasn't been initialized at all). location = nullptr;
location = nullptr; if ( detail::start_location.first_line != 0 )
if ( detail::start_location.first_line != 0 ) SetLocationInfo(&detail::start_location, &detail::end_location);
SetLocationInfo(&detail::start_location, &detail::end_location); }
}
virtual ~Obj(); virtual ~Obj();
/* disallow copying */ /* disallow copying */
Obj(const Obj&) = delete; Obj(const Obj&) = delete;
Obj& operator=(const Obj&) = delete; Obj& operator=(const Obj&) = delete;
// Report user warnings/errors. If obj2 is given, then it's // Report user warnings/errors. If obj2 is given, then it's
// included in the message, though if pinpoint_only is non-zero, // included in the message, though if pinpoint_only is non-zero,
// then obj2 is only used to pinpoint the location. // then obj2 is only used to pinpoint the location.
void Warn(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false, void Warn(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false,
const detail::Location* expr_location = nullptr) const; const detail::Location* expr_location = nullptr) const;
void Error(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false, void Error(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false,
const detail::Location* expr_location = nullptr) const; const detail::Location* expr_location = nullptr) const;
// Report internal errors. // Report internal errors.
void BadTag(const char* msg, const char* t1 = nullptr, const char* t2 = nullptr) const; void BadTag(const char* msg, const char* t1 = nullptr, const char* t2 = nullptr) const;
#define CHECK_TAG(t1, t2, text, tag_to_text_func) \ #define CHECK_TAG(t1, t2, text, tag_to_text_func) \
{ \ { \
if ( t1 != t2 ) \ if ( t1 != t2 ) \
BadTag(text, tag_to_text_func(t1), tag_to_text_func(t2)); \ BadTag(text, tag_to_text_func(t1), tag_to_text_func(t2)); \
} }
[[noreturn]] void Internal(const char* msg) const; [[noreturn]] void Internal(const char* msg) const;
void InternalWarning(const char* msg) const; void InternalWarning(const char* msg) const;
virtual void Describe(ODesc* d) const {/* FIXME: Add code */}; virtual void Describe(ODesc* d) const {/* FIXME: Add code */};
void AddLocation(ODesc* d) const; void AddLocation(ODesc* d) const;
// Get location info for debugging. // Get location info for debugging.
virtual const detail::Location* GetLocationInfo() const virtual const detail::Location* GetLocationInfo() const { return location ? location : &detail::no_location; }
{
return location ? location : &detail::no_location;
}
virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); } virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); }
// Location = range from start to end. // Location = range from start to end.
virtual bool SetLocationInfo(const detail::Location* start, const detail::Location* end); virtual bool SetLocationInfo(const detail::Location* start, const detail::Location* end);
// Set new end-of-location information. This is used to // Set new end-of-location information. This is used to
// extend compound objects such as statement lists. // extend compound objects such as statement lists.
virtual void UpdateLocationEndInfo(const detail::Location& end); virtual void UpdateLocationEndInfo(const detail::Location& end);
// Enable notification of plugins when this objects gets destroyed. // Enable notification of plugins when this objects gets destroyed.
void NotifyPluginsOnDtor() { notify_plugins = true; } void NotifyPluginsOnDtor() { notify_plugins = true; }
int RefCnt() const { return ref_cnt; } int RefCnt() const { return ref_cnt; }
// Helper class to temporarily suppress errors // Helper class to temporarily suppress errors
// as long as there exist any instances. // as long as there exist any instances.
class SuppressErrors class SuppressErrors {
{ public:
public: SuppressErrors() { ++Obj::suppress_errors; }
SuppressErrors() { ++Obj::suppress_errors; } ~SuppressErrors() { --Obj::suppress_errors; }
~SuppressErrors() { --Obj::suppress_errors; } };
};
void Print() const; void Print() const;
protected: protected:
detail::Location* location; // all that matters in real estate detail::Location* location; // all that matters in real estate
private: private:
friend class SuppressErrors; friend class SuppressErrors;
void DoMsg(ODesc* d, const char s1[], const Obj* obj2 = nullptr, bool pinpoint_only = false, void DoMsg(ODesc* d, const char s1[], const Obj* obj2 = nullptr, bool pinpoint_only = false,
const detail::Location* expr_location = nullptr) const; const detail::Location* expr_location = nullptr) const;
void PinPoint(ODesc* d, const Obj* obj2 = nullptr, bool pinpoint_only = false) const; void PinPoint(ODesc* d, const Obj* obj2 = nullptr, bool pinpoint_only = false) const;
friend inline void Ref(Obj* o); friend inline void Ref(Obj* o);
friend inline void Unref(Obj* o); friend inline void Unref(Obj* o);
int ref_cnt = 1; int ref_cnt = 1;
bool notify_plugins = false; bool notify_plugins = false;
// If non-zero, do not print runtime errors. Useful for // If non-zero, do not print runtime errors. Useful for
// speculative evaluation. // speculative evaluation.
static int suppress_errors; static int suppress_errors;
}; };
// Sometimes useful when dealing with Obj subclasses that have their // Sometimes useful when dealing with Obj subclasses that have their
// own (protected) versions of Error. // own (protected) versions of Error.
inline void Error(const Obj* o, const char* msg) inline void Error(const Obj* o, const char* msg) { o->Error(msg); }
{
o->Error(msg);
}
[[noreturn]] extern void bad_ref(int type); [[noreturn]] extern void bad_ref(int type);
inline void Ref(Obj* o) inline void Ref(Obj* o) {
{ if ( ++(o->ref_cnt) <= 1 )
if ( ++(o->ref_cnt) <= 1 ) bad_ref(0);
bad_ref(0); if ( o->ref_cnt == INT_MAX )
if ( o->ref_cnt == INT_MAX ) bad_ref(1);
bad_ref(1); }
}
inline void Unref(Obj* o) inline void Unref(Obj* o) {
{ if ( o && --o->ref_cnt <= 0 ) {
if ( o && --o->ref_cnt <= 0 ) if ( o->ref_cnt < 0 )
{ bad_ref(2);
if ( o->ref_cnt < 0 ) delete o;
bad_ref(2);
delete o;
// We could do the following if o were passed by reference. // We could do the following if o were passed by reference.
// o = (Obj*) 0xcd; // o = (Obj*) 0xcd;
} }
} }
// A dict_delete_func that knows to Unref() dictionary entries. // A dict_delete_func that knows to Unref() dictionary entries.
extern void obj_delete_func(void* v); extern void obj_delete_func(void* v);
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -21,22 +21,18 @@
#include "zeek/telemetry/Gauge.h" #include "zeek/telemetry/Gauge.h"
#include "zeek/telemetry/Histogram.h" #include "zeek/telemetry/Histogram.h"
namespace broker namespace broker {
{
class data; class data;
} }
namespace zeek namespace zeek {
{
namespace probabilistic namespace probabilistic {
{
class BloomFilter; class BloomFilter;
} }
namespace probabilistic::detail namespace probabilistic::detail {
{
class CardinalityCounter; class CardinalityCounter;
} }
class OpaqueVal; class OpaqueVal;
using OpaqueValPtr = IntrusivePtr<OpaqueVal>; using OpaqueValPtr = IntrusivePtr<OpaqueVal>;
@ -48,60 +44,59 @@ using BloomFilterValPtr = IntrusivePtr<BloomFilterVal>;
* Singleton that registers all available all available types of opaque * Singleton that registers all available all available types of opaque
* values. This facilitates their serialization into Broker values. * values. This facilitates their serialization into Broker values.
*/ */
class OpaqueMgr class OpaqueMgr {
{
public: public:
using Factory = OpaqueValPtr(); using Factory = OpaqueValPtr();
/** /**
* Return's a unique ID for the type of an opaque value. * Return's a unique ID for the type of an opaque value.
* @param v opaque value to return type for; its class must have been * @param v opaque value to return type for; its class must have been
* registered with the manager, otherwise this method will abort * registered with the manager, otherwise this method will abort
* execution. * execution.
* *
* @return type ID, which can used with *Instantiate()* to create a * @return type ID, which can used with *Instantiate()* to create a
* new instance of the same type. * new instance of the same type.
*/ */
const std::string& TypeID(const OpaqueVal* v) const; const std::string& TypeID(const OpaqueVal* v) const;
/** /**
* Instantiates a new opaque value of a specific opaque type. * Instantiates a new opaque value of a specific opaque type.
* *
* @param id unique type ID for the class to instantiate; this will * @param id unique type ID for the class to instantiate; this will
* normally have been returned earlier by *TypeID()*. * normally have been returned earlier by *TypeID()*.
* *
* @return A freshly instantiated value of the OpaqueVal-derived * @return A freshly instantiated value of the OpaqueVal-derived
* classes that *id* specifies, with reference count at +1. If *id* * classes that *id* specifies, with reference count at +1. If *id*
* is unknown, this will return null. * is unknown, this will return null.
* *
*/ */
OpaqueValPtr Instantiate(const std::string& id) const; OpaqueValPtr Instantiate(const std::string& id) const;
/** Returns the global manager singleton object. */ /** Returns the global manager singleton object. */
static OpaqueMgr* mgr(); static OpaqueMgr* mgr();
/** /**
* Internal helper class to register an OpaqueVal-derived classes * Internal helper class to register an OpaqueVal-derived classes
* with the manager. * with the manager.
*/ */
template <class T> class Register template<class T>
{ class Register {
public: public:
Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); } Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); }
}; };
private: private:
std::unordered_map<std::string, Factory*> _types; std::unordered_map<std::string, Factory*> _types;
}; };
/** Macro to insert into an OpaqueVal-derived class's declaration. */ /** Macro to insert into an OpaqueVal-derived class's declaration. */
#define DECLARE_OPAQUE_VALUE(T) \ #define DECLARE_OPAQUE_VALUE(T) \
friend class zeek::OpaqueMgr::Register<T>; \ friend class zeek::OpaqueMgr::Register<T>; \
friend zeek::IntrusivePtr<T> zeek::make_intrusive<T>(); \ friend zeek::IntrusivePtr<T> zeek::make_intrusive<T>(); \
broker::expected<broker::data> DoSerialize() const override; \ broker::expected<broker::data> DoSerialize() const override; \
bool DoUnserialize(const broker::data& data) override; \ bool DoUnserialize(const broker::data& data) override; \
const char* OpaqueName() const override { return #T; } \ const char* OpaqueName() const override { return #T; } \
static zeek::OpaqueValPtr OpaqueInstantiate() { return zeek::make_intrusive<T>(); } static zeek::OpaqueValPtr OpaqueInstantiate() { return zeek::make_intrusive<T>(); }
#define __OPAQUE_MERGE(a, b) a##b #define __OPAQUE_MERGE(a, b) a##b
#define __OPAQUE_ID(x) __OPAQUE_MERGE(_opaque, x) #define __OPAQUE_ID(x) __OPAQUE_MERGE(_opaque, x)
@ -114,348 +109,335 @@ private:
* completely internally, with no further script-level operators provided * completely internally, with no further script-level operators provided
* (other than bif functions). See OpaqueVal.h for derived classes. * (other than bif functions). See OpaqueVal.h for derived classes.
*/ */
class OpaqueVal : public Val class OpaqueVal : public Val {
{
public: public:
explicit OpaqueVal(OpaqueTypePtr t); explicit OpaqueVal(OpaqueTypePtr t);
~OpaqueVal() override = default; ~OpaqueVal() override = default;
/** /**
* Serializes the value into a Broker representation. * Serializes the value into a Broker representation.
* *
* @return the broker representation, or an error if serialization * @return the broker representation, or an error if serialization
* isn't supported or failed. * isn't supported or failed.
*/ */
broker::expected<broker::data> Serialize() const; broker::expected<broker::data> Serialize() const;
/** /**
* Reinstantiates a value from its serialized Broker representation. * Reinstantiates a value from its serialized Broker representation.
* *
* @param data Broker representation as returned by *Serialize()*. * @param data Broker representation as returned by *Serialize()*.
* @return unserialized instances with reference count at +1 * @return unserialized instances with reference count at +1
*/ */
static OpaqueValPtr Unserialize(const broker::data& data); static OpaqueValPtr Unserialize(const broker::data& data);
protected: protected:
friend class Val; friend class Val;
friend class OpaqueMgr; friend class OpaqueMgr;
/** /**
* Must be overridden to provide a serialized version of the derived * Must be overridden to provide a serialized version of the derived
* class' state. * class' state.
* *
* @return the serialized data or an error if serialization * @return the serialized data or an error if serialization
* isn't supported or failed. * isn't supported or failed.
*/ */
virtual broker::expected<broker::data> DoSerialize() const = 0; virtual broker::expected<broker::data> DoSerialize() const = 0;
/** /**
* Must be overridden to recreate the derived class' state from a * Must be overridden to recreate the derived class' state from a
* serialization. * serialization.
* *
* @return true if successful. * @return true if successful.
*/ */
virtual bool DoUnserialize(const broker::data& data) = 0; virtual bool DoUnserialize(const broker::data& data) = 0;
/** /**
* Internal helper for the serialization machinery. Automatically * Internal helper for the serialization machinery. Automatically
* overridden by the `DECLARE_OPAQUE_VALUE` macro. * overridden by the `DECLARE_OPAQUE_VALUE` macro.
*/ */
virtual const char* OpaqueName() const = 0; virtual const char* OpaqueName() const = 0;
/** /**
* Provides an implementation of *Val::DoClone()* that leverages the * Provides an implementation of *Val::DoClone()* that leverages the
* serialization methods to deep-copy an instance. Derived classes * serialization methods to deep-copy an instance. Derived classes
* may also override this with a more efficient custom clone * may also override this with a more efficient custom clone
* implementation of their own. * implementation of their own.
*/ */
ValPtr DoClone(CloneState* state) override; ValPtr DoClone(CloneState* state) override;
/** /**
* Helper function for derived class that need to record a type * Helper function for derived class that need to record a type
* during serialization. * during serialization.
*/ */
static broker::expected<broker::data> SerializeType(const TypePtr& t); static broker::expected<broker::data> SerializeType(const TypePtr& t);
/** /**
* Helper function for derived class that need to restore a type * Helper function for derived class that need to restore a type
* during unserialization. Returns the type at reference count +1. * during unserialization. Returns the type at reference count +1.
*/ */
static TypePtr UnserializeType(const broker::data& data); static TypePtr UnserializeType(const broker::data& data);
void ValDescribe(ODesc* d) const override; void ValDescribe(ODesc* d) const override;
void ValDescribeReST(ODesc* d) const override; void ValDescribeReST(ODesc* d) const override;
}; };
class HashVal : public OpaqueVal class HashVal : public OpaqueVal {
{
public: public:
template <class T> template<class T>
static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) {
{ auto h = detail::hash_init(alg);
auto h = detail::hash_init(alg);
for ( const auto& v : vlist ) for ( const auto& v : vlist )
digest_one(h, v); digest_one(h, v);
detail::hash_final(h, result); detail::hash_final(h, result);
} }
bool IsValid() const; bool IsValid() const;
bool Init(); bool Init();
bool Feed(const void* data, size_t size); bool Feed(const void* data, size_t size);
StringValPtr Get(); StringValPtr Get();
protected: protected:
static void digest_one(EVP_MD_CTX* h, const Val* v); static void digest_one(EVP_MD_CTX* h, const Val* v);
static void digest_one(EVP_MD_CTX* h, const ValPtr& v); static void digest_one(EVP_MD_CTX* h, const ValPtr& v);
explicit HashVal(OpaqueTypePtr t); explicit HashVal(OpaqueTypePtr t);
virtual bool DoInit(); virtual bool DoInit();
virtual bool DoFeed(const void* data, size_t size); virtual bool DoFeed(const void* data, size_t size);
virtual StringValPtr DoGet(); virtual StringValPtr DoGet();
private: private:
// This flag exists because Get() can only be called once. // This flag exists because Get() can only be called once.
bool valid; bool valid;
}; };
class MD5Val : public HashVal class MD5Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) {
digest_all(detail::Hash_MD5, vlist, result); digest_all(detail::Hash_MD5, vlist, result);
} }
template <class T> template<class T>
static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], u_char result[MD5_DIGEST_LENGTH]) {
u_char result[MD5_DIGEST_LENGTH]) digest(vlist, result);
{
digest(vlist, result);
for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i ) for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i )
result[i] ^= key[i]; result[i] ^= key[i];
detail::internal_md5(result, MD5_DIGEST_LENGTH, result); detail::internal_md5(result, MD5_DIGEST_LENGTH, result);
} }
MD5Val(); MD5Val();
~MD5Val(); ~MD5Val();
ValPtr DoClone(CloneState* state) override; ValPtr DoClone(CloneState* state) override;
protected: protected:
friend class Val; friend class Val;
bool DoInit() override; bool DoInit() override;
bool DoFeed(const void* data, size_t size) override; bool DoFeed(const void* data, size_t size) override;
StringValPtr DoGet() override; StringValPtr DoGet() override;
DECLARE_OPAQUE_VALUE(MD5Val) DECLARE_OPAQUE_VALUE(MD5Val)
private: private:
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
EVP_MD_CTX* ctx; EVP_MD_CTX* ctx;
#else #else
MD5_CTX ctx; MD5_CTX ctx;
#endif #endif
}; };
class SHA1Val : public HashVal class SHA1Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) {
digest_all(detail::Hash_SHA1, vlist, result); digest_all(detail::Hash_SHA1, vlist, result);
} }
SHA1Val(); SHA1Val();
~SHA1Val(); ~SHA1Val();
ValPtr DoClone(CloneState* state) override; ValPtr DoClone(CloneState* state) override;
protected: protected:
friend class Val; friend class Val;
bool DoInit() override; bool DoInit() override;
bool DoFeed(const void* data, size_t size) override; bool DoFeed(const void* data, size_t size) override;
StringValPtr DoGet() override; StringValPtr DoGet() override;
DECLARE_OPAQUE_VALUE(SHA1Val) DECLARE_OPAQUE_VALUE(SHA1Val)
private: private:
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
EVP_MD_CTX* ctx; EVP_MD_CTX* ctx;
#else #else
SHA_CTX ctx; SHA_CTX ctx;
#endif #endif
}; };
class SHA256Val : public HashVal class SHA256Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) {
digest_all(detail::Hash_SHA256, vlist, result); digest_all(detail::Hash_SHA256, vlist, result);
} }
SHA256Val(); SHA256Val();
~SHA256Val(); ~SHA256Val();
ValPtr DoClone(CloneState* state) override; ValPtr DoClone(CloneState* state) override;
protected: protected:
friend class Val; friend class Val;
bool DoInit() override; bool DoInit() override;
bool DoFeed(const void* data, size_t size) override; bool DoFeed(const void* data, size_t size) override;
StringValPtr DoGet() override; StringValPtr DoGet() override;
DECLARE_OPAQUE_VALUE(SHA256Val) DECLARE_OPAQUE_VALUE(SHA256Val)
private: private:
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
EVP_MD_CTX* ctx; EVP_MD_CTX* ctx;
#else #else
SHA256_CTX ctx; SHA256_CTX ctx;
#endif #endif
}; };
class EntropyVal : public OpaqueVal class EntropyVal : public OpaqueVal {
{
public: public:
EntropyVal(); EntropyVal();
bool Feed(const void* data, size_t size); bool Feed(const void* data, size_t size);
bool Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, double* r_scc); bool Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, double* r_scc);
protected: protected:
friend class Val; friend class Val;
DECLARE_OPAQUE_VALUE(EntropyVal) DECLARE_OPAQUE_VALUE(EntropyVal)
private: private:
detail::RandTest state; detail::RandTest state;
}; };
class BloomFilterVal : public OpaqueVal class BloomFilterVal : public OpaqueVal {
{
public: public:
explicit BloomFilterVal(probabilistic::BloomFilter* bf); explicit BloomFilterVal(probabilistic::BloomFilter* bf);
~BloomFilterVal() override; ~BloomFilterVal() override;
ValPtr DoClone(CloneState* state) override; ValPtr DoClone(CloneState* state) override;
const TypePtr& Type() const { return type; } const TypePtr& Type() const { return type; }
bool Typify(TypePtr type); bool Typify(TypePtr type);
void Add(const Val* val); void Add(const Val* val);
bool Decrement(const Val* val); bool Decrement(const Val* val);
size_t Count(const Val* val) const; size_t Count(const Val* val) const;
void Clear(); void Clear();
bool Empty() const; bool Empty() const;
std::string InternalState() const; std::string InternalState() const;
static BloomFilterValPtr Merge(const BloomFilterVal* x, const BloomFilterVal* y); static BloomFilterValPtr Merge(const BloomFilterVal* x, const BloomFilterVal* y);
static BloomFilterValPtr Intersect(const BloomFilterVal* x, const BloomFilterVal* y); static BloomFilterValPtr Intersect(const BloomFilterVal* x, const BloomFilterVal* y);
protected: protected:
friend class Val; friend class Val;
BloomFilterVal(); BloomFilterVal();
DECLARE_OPAQUE_VALUE(BloomFilterVal) DECLARE_OPAQUE_VALUE(BloomFilterVal)
private: private:
// Disable. // Disable.
BloomFilterVal(const BloomFilterVal&); BloomFilterVal(const BloomFilterVal&);
BloomFilterVal& operator=(const BloomFilterVal&); BloomFilterVal& operator=(const BloomFilterVal&);
TypePtr type; TypePtr type;
detail::CompositeHash* hash; detail::CompositeHash* hash;
probabilistic::BloomFilter* bloom_filter; probabilistic::BloomFilter* bloom_filter;
}; };
class CardinalityVal : public OpaqueVal class CardinalityVal : public OpaqueVal {
{
public: public:
explicit CardinalityVal(probabilistic::detail::CardinalityCounter*); explicit CardinalityVal(probabilistic::detail::CardinalityCounter*);
~CardinalityVal() override; ~CardinalityVal() override;
ValPtr DoClone(CloneState* state) override; ValPtr DoClone(CloneState* state) override;
void Add(const Val* val); void Add(const Val* val);
const TypePtr& Type() const { return type; } const TypePtr& Type() const { return type; }
bool Typify(TypePtr type); bool Typify(TypePtr type);
probabilistic::detail::CardinalityCounter* Get() { return c; }; probabilistic::detail::CardinalityCounter* Get() { return c; };
protected: protected:
CardinalityVal(); CardinalityVal();
DECLARE_OPAQUE_VALUE(CardinalityVal) DECLARE_OPAQUE_VALUE(CardinalityVal)
private: private:
TypePtr type; TypePtr type;
detail::CompositeHash* hash; detail::CompositeHash* hash;
probabilistic::detail::CardinalityCounter* c; probabilistic::detail::CardinalityCounter* c;
}; };
class ParaglobVal : public OpaqueVal class ParaglobVal : public OpaqueVal {
{
public: public:
explicit ParaglobVal(std::unique_ptr<paraglob::Paraglob> p); explicit ParaglobVal(std::unique_ptr<paraglob::Paraglob> p);
VectorValPtr Get(StringVal*& pattern); VectorValPtr Get(StringVal*& pattern);
ValPtr DoClone(CloneState* state) override; ValPtr DoClone(CloneState* state) override;
bool operator==(const ParaglobVal& other) const; bool operator==(const ParaglobVal& other) const;
protected: protected:
ParaglobVal() : OpaqueVal(paraglob_type) { } ParaglobVal() : OpaqueVal(paraglob_type) {}
DECLARE_OPAQUE_VALUE(ParaglobVal) DECLARE_OPAQUE_VALUE(ParaglobVal)
private: private:
std::unique_ptr<paraglob::Paraglob> internal_paraglob; std::unique_ptr<paraglob::Paraglob> internal_paraglob;
}; };
/** /**
* Base class for metric handles. Handle types are not serializable. * Base class for metric handles. Handle types are not serializable.
*/ */
class TelemetryVal : public OpaqueVal class TelemetryVal : public OpaqueVal {
{
protected: protected:
explicit TelemetryVal(telemetry::IntCounter); explicit TelemetryVal(telemetry::IntCounter);
explicit TelemetryVal(telemetry::IntCounterFamily); explicit TelemetryVal(telemetry::IntCounterFamily);
explicit TelemetryVal(telemetry::DblCounter); explicit TelemetryVal(telemetry::DblCounter);
explicit TelemetryVal(telemetry::DblCounterFamily); explicit TelemetryVal(telemetry::DblCounterFamily);
explicit TelemetryVal(telemetry::IntGauge); explicit TelemetryVal(telemetry::IntGauge);
explicit TelemetryVal(telemetry::IntGaugeFamily); explicit TelemetryVal(telemetry::IntGaugeFamily);
explicit TelemetryVal(telemetry::DblGauge); explicit TelemetryVal(telemetry::DblGauge);
explicit TelemetryVal(telemetry::DblGaugeFamily); explicit TelemetryVal(telemetry::DblGaugeFamily);
explicit TelemetryVal(telemetry::IntHistogram); explicit TelemetryVal(telemetry::IntHistogram);
explicit TelemetryVal(telemetry::IntHistogramFamily); explicit TelemetryVal(telemetry::IntHistogramFamily);
explicit TelemetryVal(telemetry::DblHistogram); explicit TelemetryVal(telemetry::DblHistogram);
explicit TelemetryVal(telemetry::DblHistogramFamily); explicit TelemetryVal(telemetry::DblHistogramFamily);
broker::expected<broker::data> DoSerialize() const override; broker::expected<broker::data> DoSerialize() const override;
bool DoUnserialize(const broker::data& data) override; bool DoUnserialize(const broker::data& data) override;
}; };
template <class Handle> class TelemetryValImpl : public TelemetryVal template<class Handle>
{ class TelemetryValImpl : public TelemetryVal {
public: public:
using HandleType = Handle; using HandleType = Handle;
explicit TelemetryValImpl(Handle hdl) : TelemetryVal(hdl), hdl(hdl) { } explicit TelemetryValImpl(Handle hdl) : TelemetryVal(hdl), hdl(hdl) {}
Handle GetHandle() const noexcept { return hdl; } Handle GetHandle() const noexcept { return hdl; }
protected: protected:
ValPtr DoClone(CloneState*) override { return make_intrusive<TelemetryValImpl>(hdl); } ValPtr DoClone(CloneState*) override { return make_intrusive<TelemetryValImpl>(hdl); }
const char* OpaqueName() const override { return Handle::OpaqueName; } const char* OpaqueName() const override { return Handle::OpaqueName; }
private: private:
Handle hdl; Handle hdl;
}; };
using IntCounterMetricVal = TelemetryValImpl<telemetry::IntCounter>; using IntCounterMetricVal = TelemetryValImpl<telemetry::IntCounter>;
using IntCounterMetricFamilyVal = TelemetryValImpl<telemetry::IntCounterFamily>; using IntCounterMetricFamilyVal = TelemetryValImpl<telemetry::IntCounterFamily>;
@ -470,4 +452,4 @@ using IntHistogramMetricFamilyVal = TelemetryValImpl<telemetry::IntHistogramFami
using DblHistogramMetricVal = TelemetryValImpl<telemetry::DblHistogram>; using DblHistogramMetricVal = TelemetryValImpl<telemetry::DblHistogram>;
using DblHistogramMetricFamilyVal = TelemetryValImpl<telemetry::DblHistogramFamily>; using DblHistogramMetricFamilyVal = TelemetryValImpl<telemetry::DblHistogramFamily>;
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -9,83 +9,81 @@
#include "zeek/DNS_Mgr.h" #include "zeek/DNS_Mgr.h"
#include "zeek/script_opt/ScriptOpt.h" #include "zeek/script_opt/ScriptOpt.h"
namespace zeek namespace zeek {
{
/** /**
* Options that define general Zeek processing behavior, usually determined * Options that define general Zeek processing behavior, usually determined
* from command-line arguments. * from command-line arguments.
*/ */
struct Options struct Options {
{ /**
/** * Unset options that aren't meant to be used by the supervisor, but may
* Unset options that aren't meant to be used by the supervisor, but may * make sense for supervised nodes to inherit (as opposed to flagging
* make sense for supervised nodes to inherit (as opposed to flagging * as an error an exiting outright if used in supervisor-mode).
* as an error an exiting outright if used in supervisor-mode). */
*/ void filter_supervisor_options();
void filter_supervisor_options();
/** /**
* Inherit certain options set in the original supervisor parent process * Inherit certain options set in the original supervisor parent process
* and discard the rest. * and discard the rest.
*/ */
void filter_supervised_node_options(); void filter_supervised_node_options();
bool print_version = false; bool print_version = false;
bool print_build_info = false; bool print_build_info = false;
bool print_usage = false; bool print_usage = false;
bool print_execution_time = false; bool print_execution_time = false;
bool print_signature_debug_info = false; bool print_signature_debug_info = false;
int print_plugins = 0; int print_plugins = 0;
std::optional<std::string> debug_log_streams; std::optional<std::string> debug_log_streams;
std::optional<std::string> debug_script_tracing_file; std::optional<std::string> debug_script_tracing_file;
std::optional<std::string> identifier_to_print; std::optional<std::string> identifier_to_print;
std::optional<std::string> script_code_to_exec; std::optional<std::string> script_code_to_exec;
std::vector<std::string> script_prefixes = {""}; // "" = "no prefix" std::vector<std::string> script_prefixes = {""}; // "" = "no prefix"
int signature_re_level = 4; int signature_re_level = 4;
bool ignore_checksums = false; bool ignore_checksums = false;
bool use_watchdog = false; bool use_watchdog = false;
double pseudo_realtime = 0; double pseudo_realtime = 0;
detail::DNS_MgrMode dns_mode = detail::DNS_DEFAULT; detail::DNS_MgrMode dns_mode = detail::DNS_DEFAULT;
bool supervisor_mode = false; bool supervisor_mode = false;
bool parse_only = false; bool parse_only = false;
bool bare_mode = false; bool bare_mode = false;
bool debug_scripts = false; bool debug_scripts = false;
bool perftools_check_leaks = false; bool perftools_check_leaks = false;
bool perftools_profile = false; bool perftools_profile = false;
bool deterministic_mode = false; bool deterministic_mode = false;
bool abort_on_scripting_errors = false; bool abort_on_scripting_errors = false;
bool no_unused_warnings = false; bool no_unused_warnings = false;
bool run_unit_tests = false; bool run_unit_tests = false;
std::vector<std::string> doctest_args; std::vector<std::string> doctest_args;
std::optional<std::string> pcap_filter; std::optional<std::string> pcap_filter;
std::optional<std::string> interface; std::optional<std::string> interface;
std::optional<std::string> pcap_file; std::optional<std::string> pcap_file;
std::vector<std::string> signature_files; std::vector<std::string> signature_files;
std::optional<std::string> pcap_output_file; std::optional<std::string> pcap_output_file;
std::optional<std::string> random_seed_input_file; std::optional<std::string> random_seed_input_file;
std::optional<std::string> random_seed_output_file; std::optional<std::string> random_seed_output_file;
std::optional<std::string> process_status_file; std::optional<std::string> process_status_file;
std::optional<std::string> zeekygen_config_file; std::optional<std::string> zeekygen_config_file;
std::optional<std::string> unprocessed_output_file; std::optional<std::string> unprocessed_output_file;
std::optional<std::string> event_trace_file; std::optional<std::string> event_trace_file;
std::set<std::string> plugins_to_load; std::set<std::string> plugins_to_load;
std::vector<std::string> scripts_to_load; std::vector<std::string> scripts_to_load;
std::vector<std::string> script_options_to_set; std::vector<std::string> script_options_to_set;
std::vector<std::string> script_args; std::vector<std::string> script_args;
// For script optimization: // For script optimization:
detail::AnalyOpt analysis_options; detail::AnalyOpt analysis_options;
}; };
/** /**
* Parse Zeek command-line arguments. * Parse Zeek command-line arguments.
@ -107,4 +105,4 @@ void usage(const char* prog, int code = 1);
*/ */
bool fake_dns(); bool fake_dns();
} // namespace zeek } // namespace zeek

View file

@ -4,37 +4,33 @@
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val) bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val) {
{ if ( ! to_type || ! from_type )
if ( ! to_type || ! from_type ) return true;
return true;
if ( same_type(to_type, from_type) ) if ( same_type(to_type, from_type) )
return false; return false;
if ( to_type->InternalType() == TYPE_INTERNAL_DOUBLE ) if ( to_type->InternalType() == TYPE_INTERNAL_DOUBLE )
return false; return false;
if ( to_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) if ( to_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) {
{ if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE )
if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE ) return double_to_count_would_overflow(val->InternalDouble());
return double_to_count_would_overflow(val->InternalDouble()); if ( from_type->InternalType() == TYPE_INTERNAL_INT )
if ( from_type->InternalType() == TYPE_INTERNAL_INT ) return int_to_count_would_overflow(val->InternalInt());
return int_to_count_would_overflow(val->InternalInt()); }
}
if ( to_type->InternalType() == TYPE_INTERNAL_INT ) if ( to_type->InternalType() == TYPE_INTERNAL_INT ) {
{ if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE )
if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE ) return double_to_int_would_overflow(val->InternalDouble());
return double_to_int_would_overflow(val->InternalDouble()); if ( from_type->InternalType() == TYPE_INTERNAL_UNSIGNED )
if ( from_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) return count_to_int_would_overflow(val->InternalUnsigned());
return count_to_int_would_overflow(val->InternalUnsigned()); }
}
return false; return false;
} }
} } // namespace zeek::detail

View file

@ -4,29 +4,18 @@
#include "zeek/Type.h" #include "zeek/Type.h"
namespace zeek::detail namespace zeek::detail {
{
inline bool double_to_count_would_overflow(double v) inline bool double_to_count_would_overflow(double v) { return v < 0.0 || v > static_cast<double>(UINT64_MAX); }
{
return v < 0.0 || v > static_cast<double>(UINT64_MAX);
}
inline bool int_to_count_would_overflow(zeek_int_t v) inline bool int_to_count_would_overflow(zeek_int_t v) { return v < 0.0; }
{
return v < 0.0;
}
inline bool double_to_int_would_overflow(double v) inline bool double_to_int_would_overflow(double v) {
{ return v < static_cast<double>(INT64_MIN) || v > static_cast<double>(INT64_MAX);
return v < static_cast<double>(INT64_MIN) || v > static_cast<double>(INT64_MAX); }
}
inline bool count_to_int_would_overflow(zeek_uint_t v) inline bool count_to_int_would_overflow(zeek_uint_t v) { return v > INT64_MAX; }
{
return v > INT64_MAX;
}
extern bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val); extern bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val);
} } // namespace zeek::detail

View file

@ -2,121 +2,106 @@
#include "zeek/IP.h" #include "zeek/IP.h"
namespace zeek::detail namespace zeek::detail {
{
void PacketFilter::DeleteFilter(void* data) void PacketFilter::DeleteFilter(void* data) {
{ auto f = static_cast<Filter*>(data);
auto f = static_cast<Filter*>(data); delete f;
delete f; }
}
PacketFilter::PacketFilter(bool arg_default) PacketFilter::PacketFilter(bool arg_default) {
{ default_match = arg_default;
default_match = arg_default; src_filter.SetDeleteFunction(PacketFilter::DeleteFilter);
src_filter.SetDeleteFunction(PacketFilter::DeleteFilter); dst_filter.SetDeleteFunction(PacketFilter::DeleteFilter);
dst_filter.SetDeleteFunction(PacketFilter::DeleteFilter); }
}
void PacketFilter::AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability) void PacketFilter::AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability) {
{ Filter* f = new Filter;
Filter* f = new Filter; f->tcp_flags = tcp_flags;
f->tcp_flags = tcp_flags; f->probability = probability * static_cast<double>(util::detail::max_random());
f->probability = probability * static_cast<double>(util::detail::max_random()); auto prev = static_cast<Filter*>(src_filter.Insert(src, 128, f));
auto prev = static_cast<Filter*>(src_filter.Insert(src, 128, f)); delete prev;
delete prev; }
}
void PacketFilter::AddSrc(Val* src, uint32_t tcp_flags, double probability) void PacketFilter::AddSrc(Val* src, uint32_t tcp_flags, double probability) {
{ Filter* f = new Filter;
Filter* f = new Filter; f->tcp_flags = tcp_flags;
f->tcp_flags = tcp_flags; f->probability = probability * static_cast<double>(util::detail::max_random());
f->probability = probability * static_cast<double>(util::detail::max_random()); auto prev = static_cast<Filter*>(src_filter.Insert(src, f));
auto prev = static_cast<Filter*>(src_filter.Insert(src, f)); delete prev;
delete prev; }
}
void PacketFilter::AddDst(const IPAddr& dst, uint32_t tcp_flags, double probability) void PacketFilter::AddDst(const IPAddr& dst, uint32_t tcp_flags, double probability) {
{ Filter* f = new Filter;
Filter* f = new Filter; f->tcp_flags = tcp_flags;
f->tcp_flags = tcp_flags; f->probability = probability * static_cast<double>(util::detail::max_random());
f->probability = probability * static_cast<double>(util::detail::max_random()); auto prev = static_cast<Filter*>(dst_filter.Insert(dst, 128, f));
auto prev = static_cast<Filter*>(dst_filter.Insert(dst, 128, f)); delete prev;
delete prev; }
}
void PacketFilter::AddDst(Val* dst, uint32_t tcp_flags, double probability) void PacketFilter::AddDst(Val* dst, uint32_t tcp_flags, double probability) {
{ Filter* f = new Filter;
Filter* f = new Filter; f->tcp_flags = tcp_flags;
f->tcp_flags = tcp_flags; f->probability = probability * static_cast<double>(util::detail::max_random());
f->probability = probability * static_cast<double>(util::detail::max_random()); auto prev = static_cast<Filter*>(dst_filter.Insert(dst, f));
auto prev = static_cast<Filter*>(dst_filter.Insert(dst, f)); delete prev;
delete prev; }
}
bool PacketFilter::RemoveSrc(const IPAddr& src) bool PacketFilter::RemoveSrc(const IPAddr& src) {
{ auto f = static_cast<Filter*>(src_filter.Remove(src, 128));
auto f = static_cast<Filter*>(src_filter.Remove(src, 128)); delete f;
delete f; return f != nullptr;
return f != nullptr; }
}
bool PacketFilter::RemoveSrc(Val* src) bool PacketFilter::RemoveSrc(Val* src) {
{ auto f = static_cast<Filter*>(src_filter.Remove(src));
auto f = static_cast<Filter*>(src_filter.Remove(src)); delete f;
delete f; return f != nullptr;
return f != nullptr; }
}
bool PacketFilter::RemoveDst(const IPAddr& dst) bool PacketFilter::RemoveDst(const IPAddr& dst) {
{ auto f = static_cast<Filter*>(dst_filter.Remove(dst, 128));
auto f = static_cast<Filter*>(dst_filter.Remove(dst, 128)); delete f;
delete f; return f != nullptr;
return f != nullptr; }
}
bool PacketFilter::RemoveDst(Val* dst) bool PacketFilter::RemoveDst(Val* dst) {
{ auto f = static_cast<Filter*>(dst_filter.Remove(dst));
auto f = static_cast<Filter*>(dst_filter.Remove(dst)); delete f;
delete f; return f != nullptr;
return f != nullptr; }
}
bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) {
{ Filter* f = (Filter*)src_filter.Lookup(ip->SrcAddr(), 128);
Filter* f = (Filter*)src_filter.Lookup(ip->SrcAddr(), 128); if ( f )
if ( f ) return MatchFilter(*f, *ip, len, caplen);
return MatchFilter(*f, *ip, len, caplen);
f = (Filter*)dst_filter.Lookup(ip->DstAddr(), 128); f = (Filter*)dst_filter.Lookup(ip->DstAddr(), 128);
if ( f ) if ( f )
return MatchFilter(*f, *ip, len, caplen); return MatchFilter(*f, *ip, len, caplen);
return default_match; return default_match;
} }
bool PacketFilter::MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen) bool PacketFilter::MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen) {
{ if ( ip.NextProto() == IPPROTO_TCP && f.tcp_flags ) {
if ( ip.NextProto() == IPPROTO_TCP && f.tcp_flags ) // Caution! The packet sanity checks have not been performed yet
{ int ip_hdr_len = ip.HdrLen();
// Caution! The packet sanity checks have not been performed yet len -= ip_hdr_len; // remove IP header
int ip_hdr_len = ip.HdrLen(); caplen -= ip_hdr_len;
len -= ip_hdr_len; // remove IP header
caplen -= ip_hdr_len;
if ( (unsigned int)len < sizeof(struct tcphdr) || if ( (unsigned int)len < sizeof(struct tcphdr) || (unsigned int)caplen < sizeof(struct tcphdr) )
(unsigned int)caplen < sizeof(struct tcphdr) ) // Packet too short, will be dropped anyway.
// Packet too short, will be dropped anyway. return false;
return false;
const struct tcphdr* tp = (const struct tcphdr*)ip.Payload(); const struct tcphdr* tp = (const struct tcphdr*)ip.Payload();
if ( tp->th_flags & f.tcp_flags ) if ( tp->th_flags & f.tcp_flags )
// At least one of the flags is set, so don't drop // At least one of the flags is set, so don't drop
return false; return false;
} }
return util::detail::random_number() < f.probability; return util::detail::random_number() < f.probability;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,54 +7,50 @@
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
#include "zeek/PrefixTable.h" #include "zeek/PrefixTable.h"
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
class Val; class Val;
namespace detail namespace detail {
{
class PacketFilter class PacketFilter {
{
public: public:
explicit PacketFilter(bool arg_default); explicit PacketFilter(bool arg_default);
~PacketFilter() { } ~PacketFilter() {}
// Drops all packets from a particular source (which may be given // Drops all packets from a particular source (which may be given
// as an AddrVal or a SubnetVal) which hasn't any of TCP flags set // as an AddrVal or a SubnetVal) which hasn't any of TCP flags set
// (TH_*) with the given probability (from 0..MAX_PROB). // (TH_*) with the given probability (from 0..MAX_PROB).
void AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability); void AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability);
void AddSrc(Val* src, uint32_t tcp_flags, double probability); void AddSrc(Val* src, uint32_t tcp_flags, double probability);
void AddDst(const IPAddr& src, uint32_t tcp_flags, double probability); void AddDst(const IPAddr& src, uint32_t tcp_flags, double probability);
void AddDst(Val* src, uint32_t tcp_flags, double probability); void AddDst(Val* src, uint32_t tcp_flags, double probability);
// Removes the filter entry for the given src/dst // Removes the filter entry for the given src/dst
// Returns false if filter doesn not exist. // Returns false if filter doesn not exist.
bool RemoveSrc(const IPAddr& src); bool RemoveSrc(const IPAddr& src);
bool RemoveSrc(Val* dst); bool RemoveSrc(Val* dst);
bool RemoveDst(const IPAddr& dst); bool RemoveDst(const IPAddr& dst);
bool RemoveDst(Val* dst); bool RemoveDst(Val* dst);
// Returns true if packet matches a drop filter // Returns true if packet matches a drop filter
bool Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen); bool Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen);
private: private:
struct Filter struct Filter {
{ uint32_t tcp_flags;
uint32_t tcp_flags; double probability;
double probability; };
};
static void DeleteFilter(void* data); static void DeleteFilter(void* data);
bool MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen); bool MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen);
bool default_match; bool default_match;
PrefixTable src_filter; PrefixTable src_filter;
PrefixTable dst_filter; PrefixTable dst_filter;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -9,154 +9,135 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
static void pipe_fail(int eno) static void pipe_fail(int eno) {
{ char tmp[256];
char tmp[256]; zeek::util::zeek_strerror_r(eno, tmp, sizeof(tmp));
zeek::util::zeek_strerror_r(eno, tmp, sizeof(tmp));
if ( reporter ) if ( reporter )
reporter->FatalError("Pipe failure: %s", tmp); reporter->FatalError("Pipe failure: %s", tmp);
else else
fprintf(stderr, "Pipe failure: %s", tmp); fprintf(stderr, "Pipe failure: %s", tmp);
} }
static int set_flags(int fd, int flags) static int set_flags(int fd, int flags) {
{ auto rval = fcntl(fd, F_GETFD);
auto rval = fcntl(fd, F_GETFD);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{ rval |= flags;
rval |= flags;
if ( fcntl(fd, F_SETFD, rval) == -1 ) if ( fcntl(fd, F_SETFD, rval) == -1 )
pipe_fail(errno); pipe_fail(errno);
} }
return rval; return rval;
} }
static int unset_flags(int fd, int flags) static int unset_flags(int fd, int flags) {
{ auto rval = fcntl(fd, F_GETFD);
auto rval = fcntl(fd, F_GETFD);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{ rval &= ~flags;
rval &= ~flags;
if ( fcntl(fd, F_SETFD, rval) == -1 ) if ( fcntl(fd, F_SETFD, rval) == -1 )
pipe_fail(errno); pipe_fail(errno);
} }
return rval; return rval;
} }
static int set_status_flags(int fd, int flags) static int set_status_flags(int fd, int flags) {
{ auto rval = fcntl(fd, F_GETFL);
auto rval = fcntl(fd, F_GETFL);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{ rval |= flags;
rval |= flags;
if ( fcntl(fd, F_SETFL, rval) == -1 ) if ( fcntl(fd, F_SETFL, rval) == -1 )
pipe_fail(errno); pipe_fail(errno);
} }
return rval; return rval;
} }
static int dup_or_fail(int fd, int flags, int status_flags) static int dup_or_fail(int fd, int flags, int status_flags) {
{ int rval = dup(fd);
int rval = dup(fd);
if ( rval < 0 ) if ( rval < 0 )
pipe_fail(errno); pipe_fail(errno);
set_flags(fd, flags); set_flags(fd, flags);
set_status_flags(fd, status_flags); set_status_flags(fd, status_flags);
return rval; return rval;
} }
Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* arg_fds) Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* arg_fds) {
{ if ( arg_fds ) {
if ( arg_fds ) fds[0] = arg_fds[0];
{ fds[1] = arg_fds[1];
fds[0] = arg_fds[0]; }
fds[1] = arg_fds[1]; else {
} // pipe2 can set flags atomically, but not yet available everywhere.
else if ( ::pipe(fds) )
{ pipe_fail(errno);
// pipe2 can set flags atomically, but not yet available everywhere. }
if ( ::pipe(fds) )
pipe_fail(errno);
}
flags[0] = set_flags(fds[0], flags0); flags[0] = set_flags(fds[0], flags0);
flags[1] = set_flags(fds[1], flags1); flags[1] = set_flags(fds[1], flags1);
status_flags[0] = set_status_flags(fds[0], status_flags0); status_flags[0] = set_status_flags(fds[0], status_flags0);
status_flags[1] = set_status_flags(fds[1], status_flags1); status_flags[1] = set_status_flags(fds[1], status_flags1);
} }
void Pipe::SetFlags(int arg_flags) void Pipe::SetFlags(int arg_flags) {
{ flags[0] = set_flags(fds[0], arg_flags);
flags[0] = set_flags(fds[0], arg_flags); flags[1] = set_flags(fds[1], arg_flags);
flags[1] = set_flags(fds[1], arg_flags); }
}
void Pipe::UnsetFlags(int arg_flags) void Pipe::UnsetFlags(int arg_flags) {
{ flags[0] = unset_flags(fds[0], arg_flags);
flags[0] = unset_flags(fds[0], arg_flags); flags[1] = unset_flags(fds[1], arg_flags);
flags[1] = unset_flags(fds[1], arg_flags); }
}
Pipe::~Pipe() Pipe::~Pipe() {
{ close(fds[0]);
close(fds[0]); close(fds[1]);
close(fds[1]); }
}
Pipe::Pipe(const Pipe& other) Pipe::Pipe(const Pipe& other) {
{ fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]);
fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]); fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]);
fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]); flags[0] = other.flags[0];
flags[0] = other.flags[0]; flags[1] = other.flags[1];
flags[1] = other.flags[1]; status_flags[0] = other.status_flags[0];
status_flags[0] = other.status_flags[0]; status_flags[1] = other.status_flags[1];
status_flags[1] = other.status_flags[1]; }
}
Pipe& Pipe::operator=(const Pipe& other) Pipe& Pipe::operator=(const Pipe& other) {
{ if ( this == &other )
if ( this == &other ) return *this;
return *this;
close(fds[0]); close(fds[0]);
close(fds[1]); close(fds[1]);
fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]); fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]);
fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]); fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]);
flags[0] = other.flags[0]; flags[0] = other.flags[0];
flags[1] = other.flags[1]; flags[1] = other.flags[1];
status_flags[0] = other.status_flags[0]; status_flags[0] = other.status_flags[0];
status_flags[1] = other.status_flags[1]; status_flags[1] = other.status_flags[1];
return *this; return *this;
} }
PipePair::PipePair(int flags, int status_flags, int* fds) PipePair::PipePair(int flags, int status_flags, int* fds)
: pipes{Pipe(flags, flags, status_flags, status_flags, fds ? fds + 0 : nullptr), : pipes{Pipe(flags, flags, status_flags, status_flags, fds ? fds + 0 : nullptr),
Pipe(flags, flags, status_flags, status_flags, fds ? fds + 2 : nullptr)} Pipe(flags, flags, status_flags, status_flags, fds ? fds + 2 : nullptr)} {}
{
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -2,129 +2,125 @@
#pragma once #pragma once
namespace zeek::detail namespace zeek::detail {
{
class Pipe class Pipe {
{
public: public:
/** /**
* Create a pair of file descriptors via pipe(), or aborts if it cannot. * Create a pair of file descriptors via pipe(), or aborts if it cannot.
* @param flags0 file descriptor flags to set on read end of pipe. * @param flags0 file descriptor flags to set on read end of pipe.
* @param flags1 file descriptor flags to set on write end of pipe. * @param flags1 file descriptor flags to set on write end of pipe.
* @param status_flags0 descriptor status flags to set on read end of pipe. * @param status_flags0 descriptor status flags to set on read end of pipe.
* @param status_flags1 descriptor status flags to set on write end of pipe. * @param status_flags1 descriptor status flags to set on write end of pipe.
* @param fds may be supplied to open an existing file descriptors rather * @param fds may be supplied to open an existing file descriptors rather
* than create ones from a new pipe. Should point to memory containing * than create ones from a new pipe. Should point to memory containing
* two consecutive file descriptors, the "read" one and then the "write" one. * two consecutive file descriptors, the "read" one and then the "write" one.
*/ */
explicit Pipe(int flags0 = 0, int flags1 = 0, int status_flags0 = 0, int status_flags1 = 0, explicit Pipe(int flags0 = 0, int flags1 = 0, int status_flags0 = 0, int status_flags1 = 0, int* fds = nullptr);
int* fds = nullptr);
/** /**
* Close the pair of file descriptors owned by the object. * Close the pair of file descriptors owned by the object.
*/ */
~Pipe(); ~Pipe();
/** /**
* Make a copy of another Pipe object (file descriptors are dup'd). * Make a copy of another Pipe object (file descriptors are dup'd).
*/ */
Pipe(const Pipe& other); Pipe(const Pipe& other);
/** /**
* Assign a Pipe object by closing file descriptors and duping those of * Assign a Pipe object by closing file descriptors and duping those of
* the other. * the other.
*/ */
Pipe& operator=(const Pipe& other); Pipe& operator=(const Pipe& other);
/** /**
* @return the file descriptor associated with the read-end of the pipe. * @return the file descriptor associated with the read-end of the pipe.
*/ */
int ReadFD() const { return fds[0]; } int ReadFD() const { return fds[0]; }
/** /**
* @return the file descriptor associated with the write-end of the pipe. * @return the file descriptor associated with the write-end of the pipe.
*/ */
int WriteFD() const { return fds[1]; } int WriteFD() const { return fds[1]; }
/** /**
* Sets the given file descriptor flags for both the read and write end * Sets the given file descriptor flags for both the read and write end
* of the pipe. * of the pipe.
*/ */
void SetFlags(int flags); void SetFlags(int flags);
/** /**
* Unsets the given file descriptor flags for both the read and write end * Unsets the given file descriptor flags for both the read and write end
* of the pipe. * of the pipe.
*/ */
void UnsetFlags(int flags); void UnsetFlags(int flags);
private: private:
int fds[2]; int fds[2];
int flags[2]; int flags[2];
int status_flags[2]; int status_flags[2];
}; };
/** /**
* A pair of pipes that can be used for bi-directional IPC. * A pair of pipes that can be used for bi-directional IPC.
*/ */
class PipePair class PipePair {
{
public: public:
/** /**
* Create a pair of pipes * Create a pair of pipes
* @param flags file descriptor flags to set on pipes * @param flags file descriptor flags to set on pipes
* @status_flags descriptor status flags to set on pipes * @status_flags descriptor status flags to set on pipes
* @fds may be supplied to open existing file descriptors rather * @fds may be supplied to open existing file descriptors rather
* than create ones from a new pair of pipes. Should point to memory * than create ones from a new pair of pipes. Should point to memory
* containing four consecutive file descriptors, "read" end and "write" end * containing four consecutive file descriptors, "read" end and "write" end
* of the first pipe followed by the "read" end and "write" end of the * of the first pipe followed by the "read" end and "write" end of the
* second pipe. * second pipe.
*/ */
PipePair(int flags, int status_flags, int* fds = nullptr); PipePair(int flags, int status_flags, int* fds = nullptr);
/** /**
* @return the pipe used for receiving input * @return the pipe used for receiving input
*/ */
Pipe& In() { return pipes[swapped]; } Pipe& In() { return pipes[swapped]; }
/** /**
* @return the pipe used for sending output * @return the pipe used for sending output
*/ */
Pipe& Out() { return pipes[! swapped]; } Pipe& Out() { return pipes[! swapped]; }
/** /**
* @return the pipe used for receiving input * @return the pipe used for receiving input
*/ */
const Pipe& In() const { return pipes[swapped]; } const Pipe& In() const { return pipes[swapped]; }
/** /**
* @return the pipe used for sending output * @return the pipe used for sending output
*/ */
const Pipe& Out() const { return pipes[! swapped]; } const Pipe& Out() const { return pipes[! swapped]; }
/** /**
* @return a file descriptor that may used for receiving messages by * @return a file descriptor that may used for receiving messages by
* polling/reading it. * polling/reading it.
*/ */
int InFD() const { return In().ReadFD(); } int InFD() const { return In().ReadFD(); }
/** /**
* @return a file descriptor that may be used for sending messages by * @return a file descriptor that may be used for sending messages by
* writing to it. * writing to it.
*/ */
int OutFD() const { return Out().WriteFD(); } int OutFD() const { return Out().WriteFD(); }
/** /**
* Swaps the meaning of the pipes in the pair. E.g. call this after * Swaps the meaning of the pipes in the pair. E.g. call this after
* fork()'ing so that the child process uses the right pipe for * fork()'ing so that the child process uses the right pipe for
* reading/writing. * reading/writing.
*/ */
void Swap() { swapped = ! swapped; } void Swap() { swapped = ! swapped; }
private: private:
Pipe pipes[2]; Pipe pipes[2];
bool swapped = false; bool swapped = false;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -17,186 +17,162 @@
using namespace std; using namespace std;
struct PolicyFile struct PolicyFile {
{ PolicyFile() {
PolicyFile() filedata = nullptr;
{ lmtime = 0;
filedata = nullptr; }
lmtime = 0; ~PolicyFile() {
} delete[] filedata;
~PolicyFile() filedata = nullptr;
{ }
delete[] filedata;
filedata = nullptr;
}
time_t lmtime; time_t lmtime;
char* filedata; char* filedata;
vector<const char*> lines; vector<const char*> lines;
}; };
using PolicyFileMap = map<string, PolicyFile*>; using PolicyFileMap = map<string, PolicyFile*>;
static PolicyFileMap policy_files; static PolicyFileMap policy_files;
namespace zeek::detail namespace zeek::detail {
{
int how_many_lines_in(const char* policy_filename) int how_many_lines_in(const char* policy_filename) {
{ if ( ! policy_filename )
if ( ! policy_filename ) reporter->InternalError("NULL value passed to how_many_lines_in\n");
reporter->InternalError("NULL value passed to how_many_lines_in\n");
FILE* throwaway = fopen(policy_filename, "r"); FILE* throwaway = fopen(policy_filename, "r");
if ( ! throwaway ) if ( ! throwaway ) {
{ debug_msg("Could not open policy file: %s.\n", policy_filename);
debug_msg("Could not open policy file: %s.\n", policy_filename); return -1;
return -1; }
}
fclose(throwaway); fclose(throwaway);
PolicyFileMap::iterator match; PolicyFileMap::iterator match;
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{ match = policy_files.find(policy_filename);
match = policy_files.find(policy_filename); if ( match == policy_files.end() ) {
if ( match == policy_files.end() ) debug_msg("Policy file %s was not loaded.\n", policy_filename);
{ return -1;
debug_msg("Policy file %s was not loaded.\n", policy_filename); }
return -1; }
}
}
PolicyFile* pf = match->second; PolicyFile* pf = match->second;
return pf->lines.size(); return pf->lines.size();
} }
bool LoadPolicyFileText(const char* policy_filename, bool LoadPolicyFileText(const char* policy_filename, const std::optional<std::string>& preloaded_content) {
const std::optional<std::string>& preloaded_content) if ( ! policy_filename )
{ return true;
if ( ! policy_filename )
return true;
if ( policy_files.find(policy_filename) != policy_files.end() ) if ( policy_files.find(policy_filename) != policy_files.end() )
debug_msg("Policy file %s already loaded\n", policy_filename); debug_msg("Policy file %s already loaded\n", policy_filename);
PolicyFile* pf = new PolicyFile; PolicyFile* pf = new PolicyFile;
policy_files.insert(PolicyFileMap::value_type(policy_filename, pf)); policy_files.insert(PolicyFileMap::value_type(policy_filename, pf));
if ( preloaded_content ) if ( preloaded_content ) {
{ auto size = preloaded_content->size();
auto size = preloaded_content->size(); pf->filedata = new char[size + 1];
pf->filedata = new char[size + 1]; memcpy(pf->filedata, preloaded_content->data(), size);
memcpy(pf->filedata, preloaded_content->data(), size); pf->filedata[size] = '\0';
pf->filedata[size] = '\0'; }
} else {
else FILE* f = fopen(policy_filename, "r");
{
FILE* f = fopen(policy_filename, "r");
if ( ! f ) if ( ! f ) {
{ debug_msg("Could not open policy file: %s.\n", policy_filename);
debug_msg("Could not open policy file: %s.\n", policy_filename); return false;
return false; }
}
struct stat st; struct stat st;
if ( fstat(fileno(f), &st) != 0 ) if ( fstat(fileno(f), &st) != 0 ) {
{ char buf[256];
char buf[256]; util::zeek_strerror_r(errno, buf, sizeof(buf));
util::zeek_strerror_r(errno, buf, sizeof(buf)); reporter->Error("fstat failed on %s: %s", policy_filename, buf);
reporter->Error("fstat failed on %s: %s", policy_filename, buf); fclose(f);
fclose(f); return false;
return false; }
}
pf->lmtime = st.st_mtime; pf->lmtime = st.st_mtime;
off_t size = st.st_size; off_t size = st.st_size;
// ### This code is not necessarily Unicode safe! // ### This code is not necessarily Unicode safe!
// (probably fine with UTF-8) // (probably fine with UTF-8)
pf->filedata = new char[size + 1]; pf->filedata = new char[size + 1];
size_t n = fread(pf->filedata, 1, size, f); size_t n = fread(pf->filedata, 1, size, f);
if ( ferror(f) ) if ( ferror(f) )
reporter->InternalError("Failed to fread() file data"); reporter->InternalError("Failed to fread() file data");
pf->filedata[n] = 0; pf->filedata[n] = 0;
fclose(f); fclose(f);
} }
// Separate the string by newlines. // Separate the string by newlines.
pf->lines.push_back(pf->filedata); pf->lines.push_back(pf->filedata);
for ( char* iter = pf->filedata; *iter; ++iter ) for ( char* iter = pf->filedata; *iter; ++iter ) {
{ if ( *iter == '\n' ) {
if ( *iter == '\n' ) *iter = 0;
{ if ( *(iter + 1) )
*iter = 0; pf->lines.push_back(iter + 1);
if ( *(iter + 1) ) }
pf->lines.push_back(iter + 1); }
}
}
for ( int i = 0; i < int(pf->lines.size()); ++i ) for ( int i = 0; i < int(pf->lines.size()); ++i )
assert(pf->lines[i][0] != '\n'); assert(pf->lines[i][0] != '\n');
return true; return true;
} }
// REMEMBER: line number arguments are indexed from 0. // REMEMBER: line number arguments are indexed from 0.
bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool show_numbers) {
bool show_numbers) if ( ! policy_filename )
{ return true;
if ( ! policy_filename )
return true;
FILE* throwaway = fopen(policy_filename, "r"); FILE* throwaway = fopen(policy_filename, "r");
if ( ! throwaway ) if ( ! throwaway ) {
{ debug_msg("Could not open policy file: %s.\n", policy_filename);
debug_msg("Could not open policy file: %s.\n", policy_filename); return false;
return false; }
}
fclose(throwaway); fclose(throwaway);
PolicyFileMap::iterator match; PolicyFileMap::iterator match;
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{ match = policy_files.find(policy_filename);
match = policy_files.find(policy_filename); if ( match == policy_files.end() ) {
if ( match == policy_files.end() ) debug_msg("Policy file %s was not loaded.\n", policy_filename);
{ return false;
debug_msg("Policy file %s was not loaded.\n", policy_filename); }
return false; }
}
}
PolicyFile* pf = match->second; PolicyFile* pf = match->second;
if ( start_line < 1 ) if ( start_line < 1 )
start_line = 1; start_line = 1;
if ( start_line > pf->lines.size() ) if ( start_line > pf->lines.size() ) {
{ debug_msg("Line number %d out of range; %s has %d lines\n", start_line, policy_filename, int(pf->lines.size()));
debug_msg("Line number %d out of range; %s has %d lines\n", start_line, policy_filename, return false;
int(pf->lines.size())); }
return false;
}
if ( start_line + how_many_lines - 1 > pf->lines.size() ) if ( start_line + how_many_lines - 1 > pf->lines.size() )
how_many_lines = pf->lines.size() - start_line + 1; how_many_lines = pf->lines.size() - start_line + 1;
for ( unsigned int i = 0; i < how_many_lines; ++i ) for ( unsigned int i = 0; i < how_many_lines; ++i ) {
{ if ( show_numbers )
if ( show_numbers ) debug_msg("%d\t", i + start_line);
debug_msg("%d\t", i + start_line);
const char* line = pf->lines[start_line + i - 1]; const char* line = pf->lines[start_line + i - 1];
debug_msg("%s\n", line); debug_msg("%s\n", line);
} }
return true; return true;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -17,16 +17,13 @@
#include <optional> #include <optional>
#include <string> #include <string>
namespace zeek::detail namespace zeek::detail {
{
int how_many_lines_in(const char* policy_filename); int how_many_lines_in(const char* policy_filename);
bool LoadPolicyFileText(const char* policy_filename, bool LoadPolicyFileText(const char* policy_filename, const std::optional<std::string>& preloaded_content = {});
const std::optional<std::string>& preloaded_content = {});
// start_line is 1-based (the intuitive way) // start_line is 1-based (the intuitive way)
bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool show_numbers);
bool show_numbers);
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -3,206 +3,168 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width) prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width) {
{ prefix_t* prefix = (prefix_t*)util::safe_malloc(sizeof(prefix_t));
prefix_t* prefix = (prefix_t*)util::safe_malloc(sizeof(prefix_t));
addr.CopyIPv6(&prefix->add.sin6); addr.CopyIPv6(&prefix->add.sin6);
prefix->family = AF_INET6; prefix->family = AF_INET6;
prefix->bitlen = width; prefix->bitlen = width;
prefix->ref_count = 1; prefix->ref_count = 1;
return prefix; return prefix;
} }
IPPrefix PrefixTable::PrefixToIPPrefix(prefix_t* prefix) IPPrefix PrefixTable::PrefixToIPPrefix(prefix_t* prefix) {
{ return IPPrefix(IPAddr(IPv6, reinterpret_cast<const uint32_t*>(&prefix->add.sin6), IPAddr::Network), prefix->bitlen,
return IPPrefix( true);
IPAddr(IPv6, reinterpret_cast<const uint32_t*>(&prefix->add.sin6), IPAddr::Network), }
prefix->bitlen, true);
}
void* PrefixTable::Insert(const IPAddr& addr, int width, void* data) void* PrefixTable::Insert(const IPAddr& addr, int width, void* data) {
{ prefix_t* prefix = MakePrefix(addr, width);
prefix_t* prefix = MakePrefix(addr, width); patricia_node_t* node = patricia_lookup(tree, prefix);
patricia_node_t* node = patricia_lookup(tree, prefix); Deref_Prefix(prefix);
Deref_Prefix(prefix);
if ( ! node ) if ( ! node ) {
{ reporter->InternalWarning("Cannot create node in patricia tree");
reporter->InternalWarning("Cannot create node in patricia tree"); return nullptr;
return nullptr; }
}
void* old = node->data; void* old = node->data;
// If there is no data to be associated with addr, we take the // If there is no data to be associated with addr, we take the
// node itself. // node itself.
node->data = data ? data : node; node->data = data ? data : node;
return old; return old;
} }
void* PrefixTable::Insert(const Val* value, void* data) void* PrefixTable::Insert(const Val* value, void* data) {
{ // [elem] -> elem
// [elem] -> elem if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) value = value->AsListVal()->Idx(0).get();
value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Insert(value->AsAddr(), 128, data); break;
case TYPE_ADDR:
return Insert(value->AsAddr(), 128, data);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Insert(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), data); break;
return Insert(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), data);
break;
default: default: reporter->InternalWarning("Wrong index type for PrefixTable"); return nullptr;
reporter->InternalWarning("Wrong index type for PrefixTable"); }
return nullptr; }
}
}
std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr, int width) const std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr, int width) const {
{ std::list<std::tuple<IPPrefix, void*>> out;
std::list<std::tuple<IPPrefix, void*>> out; prefix_t* prefix = MakePrefix(addr, width);
prefix_t* prefix = MakePrefix(addr, width);
int elems = 0; int elems = 0;
patricia_node_t** list = nullptr; patricia_node_t** list = nullptr;
patricia_search_all(tree, prefix, &list, &elems); patricia_search_all(tree, prefix, &list, &elems);
for ( int i = 0; i < elems; ++i ) for ( int i = 0; i < elems; ++i )
out.emplace_back(PrefixToIPPrefix(list[i]->prefix), list[i]->data); out.emplace_back(PrefixToIPPrefix(list[i]->prefix), list[i]->data);
Deref_Prefix(prefix); Deref_Prefix(prefix);
free(list); free(list);
return out; return out;
} }
std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const SubNetVal* value) const std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const SubNetVal* value) const {
{ return FindAll(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6());
return FindAll(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6()); }
}
void* PrefixTable::Lookup(const IPAddr& addr, int width, bool exact) const void* PrefixTable::Lookup(const IPAddr& addr, int width, bool exact) const {
{ prefix_t* prefix = MakePrefix(addr, width);
prefix_t* prefix = MakePrefix(addr, width); patricia_node_t* node = exact ? patricia_search_exact(tree, prefix) : patricia_search_best(tree, prefix);
patricia_node_t* node = exact ? patricia_search_exact(tree, prefix)
: patricia_search_best(tree, prefix);
int elems = 0; int elems = 0;
patricia_node_t** list = nullptr; patricia_node_t** list = nullptr;
Deref_Prefix(prefix); Deref_Prefix(prefix);
return node ? node->data : nullptr; return node ? node->data : nullptr;
} }
void* PrefixTable::Lookup(const Val* value, bool exact) const void* PrefixTable::Lookup(const Val* value, bool exact) const {
{ // [elem] -> elem
// [elem] -> elem if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) value = value->AsListVal()->Idx(0).get();
value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Lookup(value->AsAddr(), 128, exact); break;
case TYPE_ADDR:
return Lookup(value->AsAddr(), 128, exact);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Lookup(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), exact); break;
return Lookup(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), exact);
break;
default: default:
reporter->InternalWarning("Wrong index type %d for PrefixTable", reporter->InternalWarning("Wrong index type %d for PrefixTable", value->GetType()->Tag());
value->GetType()->Tag()); return nullptr;
return nullptr; }
} }
}
void* PrefixTable::Remove(const IPAddr& addr, int width) void* PrefixTable::Remove(const IPAddr& addr, int width) {
{ prefix_t* prefix = MakePrefix(addr, width);
prefix_t* prefix = MakePrefix(addr, width); patricia_node_t* node = patricia_search_exact(tree, prefix);
patricia_node_t* node = patricia_search_exact(tree, prefix); Deref_Prefix(prefix);
Deref_Prefix(prefix);
if ( ! node ) if ( ! node )
return nullptr; return nullptr;
void* old = node->data; void* old = node->data;
patricia_remove(tree, node); patricia_remove(tree, node);
return old; return old;
} }
void* PrefixTable::Remove(const Val* value) void* PrefixTable::Remove(const Val* value) {
{ // [elem] -> elem
// [elem] -> elem if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) value = value->AsListVal()->Idx(0).get();
value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Remove(value->AsAddr(), 128); break;
case TYPE_ADDR:
return Remove(value->AsAddr(), 128);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Remove(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6()); break;
return Remove(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6());
break;
default: default: reporter->InternalWarning("Wrong index type for PrefixTable"); return nullptr;
reporter->InternalWarning("Wrong index type for PrefixTable"); }
return nullptr; }
}
}
PrefixTable::iterator PrefixTable::InitIterator() PrefixTable::iterator PrefixTable::InitIterator() {
{ iterator i;
iterator i; i.Xsp = i.Xstack;
i.Xsp = i.Xstack; i.Xrn = tree->head;
i.Xrn = tree->head; i.Xnode = nullptr;
i.Xnode = nullptr; return i;
return i; }
}
void* PrefixTable::GetNext(iterator* i) void* PrefixTable::GetNext(iterator* i) {
{ while ( true ) {
while ( true ) i->Xnode = i->Xrn;
{ if ( ! i->Xnode )
i->Xnode = i->Xrn; return nullptr;
if ( ! i->Xnode )
return nullptr;
if ( i->Xrn->l ) if ( i->Xrn->l ) {
{ if ( i->Xrn->r )
if ( i->Xrn->r ) *i->Xsp++ = i->Xrn->r;
*i->Xsp++ = i->Xrn->r;
i->Xrn = i->Xrn->l; i->Xrn = i->Xrn->l;
} }
else if ( i->Xrn->r ) else if ( i->Xrn->r )
i->Xrn = i->Xrn->r; i->Xrn = i->Xrn->r;
else if ( i->Xsp != i->Xstack ) else if ( i->Xsp != i->Xstack )
i->Xrn = *(--i->Xsp); i->Xrn = *(--i->Xsp);
else else
i->Xrn = (patricia_node_t*)nullptr; i->Xrn = (patricia_node_t*)nullptr;
if ( i->Xnode->prefix ) if ( i->Xnode->prefix )
return (void*)i->Xnode->data; return (void*)i->Xnode->data;
} }
// Not reached. // Not reached.
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -1,80 +1,74 @@
#pragma once #pragma once
extern "C" extern "C" {
{
#include "zeek/3rdparty/patricia.h" #include "zeek/3rdparty/patricia.h"
} }
#include <list> #include <list>
#include <tuple> #include <tuple>
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
namespace zeek namespace zeek {
{
class Val; class Val;
class SubNetVal; class SubNetVal;
namespace detail namespace detail {
{
class PrefixTable class PrefixTable {
{
private: private:
struct iterator struct iterator {
{ patricia_node_t* Xstack[PATRICIA_MAXBITS + 1];
patricia_node_t* Xstack[PATRICIA_MAXBITS + 1]; patricia_node_t** Xsp;
patricia_node_t** Xsp; patricia_node_t* Xrn;
patricia_node_t* Xrn; patricia_node_t* Xnode;
patricia_node_t* Xnode; };
};
public: public:
PrefixTable() PrefixTable() {
{ tree = New_Patricia(128);
tree = New_Patricia(128); delete_function = nullptr;
delete_function = nullptr; }
} ~PrefixTable() { Destroy_Patricia(tree, delete_function); }
~PrefixTable() { Destroy_Patricia(tree, delete_function); }
// Addr in network byte order. If data is zero, acts like a set. // Addr in network byte order. If data is zero, acts like a set.
// Returns ptr to old data if already existing. // Returns ptr to old data if already existing.
// For existing items without data, returns non-nil if found. // For existing items without data, returns non-nil if found.
void* Insert(const IPAddr& addr, int width, void* data = nullptr); void* Insert(const IPAddr& addr, int width, void* data = nullptr);
// Value may be addr or subnet. // Value may be addr or subnet.
void* Insert(const Val* value, void* data = nullptr); void* Insert(const Val* value, void* data = nullptr);
// Returns nil if not found, pointer to data otherwise. // Returns nil if not found, pointer to data otherwise.
// For items without data, returns non-nil if found. // For items without data, returns non-nil if found.
// If exact is false, performs exact rather than longest-prefix match. // If exact is false, performs exact rather than longest-prefix match.
void* Lookup(const IPAddr& addr, int width, bool exact = false) const; void* Lookup(const IPAddr& addr, int width, bool exact = false) const;
void* Lookup(const Val* value, bool exact = false) const; void* Lookup(const Val* value, bool exact = false) const;
// Returns list of all found matches or empty list otherwise. // Returns list of all found matches or empty list otherwise.
std::list<std::tuple<IPPrefix, void*>> FindAll(const IPAddr& addr, int width) const; std::list<std::tuple<IPPrefix, void*>> FindAll(const IPAddr& addr, int width) const;
std::list<std::tuple<IPPrefix, void*>> FindAll(const SubNetVal* value) const; std::list<std::tuple<IPPrefix, void*>> FindAll(const SubNetVal* value) const;
// Returns pointer to data or nil if not found. // Returns pointer to data or nil if not found.
void* Remove(const IPAddr& addr, int width); void* Remove(const IPAddr& addr, int width);
void* Remove(const Val* value); void* Remove(const Val* value);
void Clear() { Clear_Patricia(tree, delete_function); } void Clear() { Clear_Patricia(tree, delete_function); }
// Sets a function to call for each node when table is cleared/destroyed. // Sets a function to call for each node when table is cleared/destroyed.
void SetDeleteFunction(data_fn_t del_fn) { delete_function = del_fn; } void SetDeleteFunction(data_fn_t del_fn) { delete_function = del_fn; }
iterator InitIterator(); iterator InitIterator();
void* GetNext(iterator* i); void* GetNext(iterator* i);
private: private:
static prefix_t* MakePrefix(const IPAddr& addr, int width); static prefix_t* MakePrefix(const IPAddr& addr, int width);
static IPPrefix PrefixToIPPrefix(prefix_t* p); static IPPrefix PrefixToIPPrefix(prefix_t* p);
patricia_tree_t* tree; patricia_tree_t* tree;
data_fn_t delete_function; data_fn_t delete_function;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -10,133 +10,116 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
PriorityQueue::PriorityQueue(int initial_size) : max_heap_size(initial_size) PriorityQueue::PriorityQueue(int initial_size) : max_heap_size(initial_size) { heap = new PQ_Element*[max_heap_size]; }
{
heap = new PQ_Element*[max_heap_size];
}
PriorityQueue::~PriorityQueue() PriorityQueue::~PriorityQueue() {
{ for ( int i = 0; i < heap_size; ++i )
for ( int i = 0; i < heap_size; ++i ) delete heap[i];
delete heap[i];
delete[] heap; delete[] heap;
} }
PQ_Element* PriorityQueue::Remove() PQ_Element* PriorityQueue::Remove() {
{ if ( heap_size == 0 )
if ( heap_size == 0 ) return nullptr;
return nullptr;
PQ_Element* top = heap[0]; PQ_Element* top = heap[0];
--heap_size; --heap_size;
SetElement(0, heap[heap_size]); SetElement(0, heap[heap_size]);
BubbleDown(0); BubbleDown(0);
top->SetOffset(-1); // = not in heap top->SetOffset(-1); // = not in heap
return top; return top;
} }
PQ_Element* PriorityQueue::Remove(PQ_Element* e) PQ_Element* PriorityQueue::Remove(PQ_Element* e) {
{ if ( e->Offset() < 0 || e->Offset() >= heap_size || heap[e->Offset()] != e )
if ( e->Offset() < 0 || e->Offset() >= heap_size || heap[e->Offset()] != e ) return nullptr; // not in heap
return nullptr; // not in heap
e->MinimizeTime(); e->MinimizeTime();
BubbleUp(e->Offset()); BubbleUp(e->Offset());
PQ_Element* e2 = Remove(); PQ_Element* e2 = Remove();
if ( e != e2 ) if ( e != e2 )
reporter->InternalError("inconsistency in PriorityQueue::Remove"); reporter->InternalError("inconsistency in PriorityQueue::Remove");
return e2; return e2;
} }
bool PriorityQueue::Add(PQ_Element* e) bool PriorityQueue::Add(PQ_Element* e) {
{ SetElement(heap_size, e);
SetElement(heap_size, e);
BubbleUp(heap_size); BubbleUp(heap_size);
++cumulative_num; ++cumulative_num;
if ( ++heap_size > peak_heap_size ) if ( ++heap_size > peak_heap_size )
peak_heap_size = heap_size; peak_heap_size = heap_size;
if ( heap_size >= max_heap_size ) if ( heap_size >= max_heap_size )
return Resize(max_heap_size * 2); return Resize(max_heap_size * 2);
else else
return true; return true;
} }
bool PriorityQueue::Resize(int new_size) bool PriorityQueue::Resize(int new_size) {
{ PQ_Element** tmp = new PQ_Element*[new_size];
PQ_Element** tmp = new PQ_Element*[new_size]; for ( int i = 0; i < max_heap_size; ++i )
for ( int i = 0; i < max_heap_size; ++i ) tmp[i] = heap[i];
tmp[i] = heap[i];
delete[] heap; delete[] heap;
heap = tmp; heap = tmp;
max_heap_size = new_size; max_heap_size = new_size;
return heap != nullptr; return heap != nullptr;
} }
void PriorityQueue::BubbleUp(int bin) void PriorityQueue::BubbleUp(int bin) {
{ if ( bin == 0 )
if ( bin == 0 ) return;
return;
int p = Parent(bin); int p = Parent(bin);
if ( heap[p]->Time() > heap[bin]->Time() ) if ( heap[p]->Time() > heap[bin]->Time() ) {
{ Swap(p, bin);
Swap(p, bin); BubbleUp(p);
BubbleUp(p); }
} }
}
void PriorityQueue::BubbleDown(int bin) void PriorityQueue::BubbleDown(int bin) {
{ double v = heap[bin]->Time();
double v = heap[bin]->Time();
int l = LeftChild(bin); int l = LeftChild(bin);
int r = RightChild(bin); int r = RightChild(bin);
if ( l >= heap_size ) if ( l >= heap_size )
return; // No children. return; // No children.
if ( r >= heap_size ) if ( r >= heap_size ) { // Just a left child.
{ // Just a left child. if ( heap[l]->Time() < v )
if ( heap[l]->Time() < v ) Swap(l, bin);
Swap(l, bin); }
}
else else {
{ double lv = heap[l]->Time();
double lv = heap[l]->Time(); double rv = heap[r]->Time();
double rv = heap[r]->Time();
if ( lv < rv ) if ( lv < rv ) {
{ if ( lv < v ) {
if ( lv < v ) Swap(l, bin);
{ BubbleDown(l);
Swap(l, bin); }
BubbleDown(l); }
}
}
else if ( rv < v ) else if ( rv < v ) {
{ Swap(r, bin);
Swap(r, bin); BubbleDown(r);
BubbleDown(r); }
} }
} }
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,91 +7,85 @@
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
namespace zeek::detail namespace zeek::detail {
{
class PriorityQueue; class PriorityQueue;
class PQ_Element class PQ_Element {
{
public: public:
explicit PQ_Element(double t) : time(t) { } explicit PQ_Element(double t) : time(t) {}
virtual ~PQ_Element() = default; virtual ~PQ_Element() = default;
double Time() const { return time; } double Time() const { return time; }
int Offset() const { return offset; } int Offset() const { return offset; }
void SetOffset(int off) { offset = off; } void SetOffset(int off) { offset = off; }
void MinimizeTime() { time = -HUGE_VAL; } void MinimizeTime() { time = -HUGE_VAL; }
protected: protected:
PQ_Element() = default; PQ_Element() = default;
double time = 0.0; double time = 0.0;
int offset = -1; int offset = -1;
}; };
class PriorityQueue class PriorityQueue {
{
public: public:
explicit PriorityQueue(int initial_size = 16); explicit PriorityQueue(int initial_size = 16);
~PriorityQueue(); ~PriorityQueue();
// Returns the top of queue, or nil if the queue is empty. // Returns the top of queue, or nil if the queue is empty.
PQ_Element* Top() const PQ_Element* Top() const {
{ if ( heap_size == 0 )
if ( heap_size == 0 ) return nullptr;
return nullptr;
return heap[0]; return heap[0];
} }
// Removes (and returns) top of queue. Returns nil if the queue // Removes (and returns) top of queue. Returns nil if the queue
// is empty. // is empty.
PQ_Element* Remove(); PQ_Element* Remove();
// Removes element e. Returns e, or nullptr if e wasn't in the queue. // Removes element e. Returns e, or nullptr if e wasn't in the queue.
// Note that e will be modified via MinimizeTime(). // Note that e will be modified via MinimizeTime().
PQ_Element* Remove(PQ_Element* e); PQ_Element* Remove(PQ_Element* e);
// Add a new element to the queue. Returns false on failure (not enough // Add a new element to the queue. Returns false on failure (not enough
// memory to add the element), true on success. // memory to add the element), true on success.
bool Add(PQ_Element* e); bool Add(PQ_Element* e);
int Size() const { return heap_size; } int Size() const { return heap_size; }
int PeakSize() const { return peak_heap_size; } int PeakSize() const { return peak_heap_size; }
uint64_t CumulativeNum() const { return cumulative_num; } uint64_t CumulativeNum() const { return cumulative_num; }
protected: protected:
bool Resize(int new_size); bool Resize(int new_size);
void BubbleUp(int bin); void BubbleUp(int bin);
void BubbleDown(int bin); void BubbleDown(int bin);
int Parent(int bin) const { return bin >> 1; } int Parent(int bin) const { return bin >> 1; }
int LeftChild(int bin) const { return bin << 1; } int LeftChild(int bin) const { return bin << 1; }
int RightChild(int bin) const { return LeftChild(bin) + 1; } int RightChild(int bin) const { return LeftChild(bin) + 1; }
void SetElement(int bin, PQ_Element* e) void SetElement(int bin, PQ_Element* e) {
{ heap[bin] = e;
heap[bin] = e; e->SetOffset(bin);
e->SetOffset(bin); }
}
void Swap(int bin1, int bin2) void Swap(int bin1, int bin2) {
{ PQ_Element* t = heap[bin1];
PQ_Element* t = heap[bin1]; SetElement(bin1, heap[bin2]);
SetElement(bin1, heap[bin2]); SetElement(bin2, t);
SetElement(bin2, t); }
}
PQ_Element** heap = nullptr; PQ_Element** heap = nullptr;
int heap_size = 0; int heap_size = 0;
int peak_heap_size = 0; int peak_heap_size = 0;
int max_heap_size = 0; int max_heap_size = 0;
uint64_t cumulative_num = 0; uint64_t cumulative_num = 0;
}; };
} // namespace zeek::detail } // namespace zeek::detail

Some files were not shown because too many files have changed in this diff Show more