This commit is contained in:
Johanna Amann 2023-11-06 09:42:46 +00:00
parent 0afe94154d
commit 36741c2fbf
787 changed files with 131811 additions and 153704 deletions

View file

@ -1,74 +1,66 @@
# Clang-format configuration for Zeek. This configuration requires # Copyright (c) 2020-2023 by the Zeek Project. See LICENSE for details.
# at least clang-format 12.0.1 to format correctly.
---
Language: Cpp Language: Cpp
Standard: c++17
BreakBeforeBraces: Whitesmiths
# BraceWrapping:
# AfterCaseLabel: true
# AfterClass: false
# AfterControlStatement: Always
# AfterEnum: false
# AfterFunction: true
# AfterNamespace: false
# AfterStruct: false
# AfterUnion: false
# AfterExternBlock: false
# BeforeCatch: true
# BeforeElse: true
# BeforeWhile: false
# IndentBraces: true
# SplitEmptyFunction: false
# SplitEmptyRecord: false
# SplitEmptyNamespace: false
AccessModifierOffset: -4 AccessModifierOffset: -4
AlignAfterOpenBracket: Align AlignAfterOpenBracket: Align
AlignTrailingComments: false AlignConsecutiveAssignments: false
AllowShortBlocksOnASingleLine: Empty AlignConsecutiveDeclarations: false
AllowShortEnumsOnASingleLine: true AlignEscapedNewlines: Right
AllowShortFunctionsOnASingleLine: Inline AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: true
AllowShortIfStatementsOnASingleLine: false AllowShortIfStatementsOnASingleLine: false
AllowShortLambdasOnASingleLine: Empty
AllowShortLoopsOnASingleLine: false AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: true BinPackArguments: true
BinPackParameters: true BinPackParameters: true
BreakConstructorInitializers: BeforeColon BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: true
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon BreakInheritanceList: BeforeColon
ColumnLimit: 100 BreakBeforeTernaryOperators: false
ConstructorInitializerAllOnOneLineOrOnePerLine: false BreakConstructorInitializersBeforeComma: false
FixNamespaceComments: false BreakConstructorInitializers: BeforeColon
IndentCaseLabels: true BreakAfterJavaFieldAnnotations: false
IndentCaseBlocks: false BreakStringLiterals: true
IndentExternBlock: NoIndent ColumnLimit: 120
IndentPPDirectives: None CommentPragmas: 'NOLINT'
IndentWidth: 4 CompactNamespaces: false
NamespaceIndentation: None ConstructorInitializerAllOnOneLineOrOnePerLine: true
PointerAlignment: Left ConstructorInitializerIndentWidth: 4
SpaceAfterCStyleCast: false ContinuationIndentWidth: 4
SpaceAfterLogicalNot: true Cpp11BracedListStyle: true
SpaceBeforeAssignmentOperators: true DerivePointerAlignment: false
SpaceBeforeCpp11BracedList: false DisableFormat: false
SpaceBeforeCtorInitializerColon: true ExperimentalAutoDetectBinPacking: false
SpaceBeforeInheritanceColon: true FixNamespaceComments: true
SpaceBeforeParens: ControlStatements ForEachMacros:
SpaceBeforeRangeBasedForLoopColon: true - foreach
SpaceInEmptyBlock: true - Q_FOREACH
SpaceInEmptyParentheses: false - BOOST_FOREACH
SpacesInAngles: false
SpacesInConditionalStatement: true
SpacesInContainerLiterals: false
SpacesInParentheses: false
TabWidth: 4
UseTab: AlignWithSpaces
# Setting this to a high number causes clang-format to prefer breaking somewhere else
# over breaking after the assignment operator in a line that's over the column limit
PenaltyBreakAssignment: 100
IncludeBlocks: Regroup IncludeBlocks: Regroup
# Include categories go like this: # Include categories go like this:
@ -98,3 +90,57 @@ IncludeCategories:
Priority: 4 Priority: 4
- Regex: '.*' - Regex: '.*'
Priority: 5 Priority: 5
IncludeIsMainRegex: '$'
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: '^BEGIN_'
MacroBlockEnd: '^END_'
MaxEmptyLinesToKeep: 2
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 500
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 1000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: false
SpaceAfterLogicalNot: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SpacesInConditionalStatement: true
Standard: Cpp11
StatementMacros:
- STANDARD_OPERATOR_1
TabWidth: 4
UseTab: Never
---
Language: Json
...

View file

@ -3,9 +3,13 @@
# #
repos: repos:
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: 'v13.0.0' rev: 'v17.0.3'
hooks: hooks:
- id: clang-format - id: clang-format
types_or:
- "c"
- "c++"
- "json"
- repo: https://github.com/maxwinterstein/shfmt-py - repo: https://github.com/maxwinterstein/shfmt-py
rev: v3.7.0.1 rev: v3.7.0.1
@ -14,7 +18,7 @@ repos:
args: ["-w", "-i", "4", "-ci"] args: ["-w", "-i", "4", "-ci"]
- repo: https://github.com/google/yapf - repo: https://github.com/google/yapf
rev: v0.40.0 rev: v0.40.2
hooks: hooks:
- id: yapf - id: yapf
@ -25,7 +29,7 @@ repos:
exclude: '^auxil/.*$' exclude: '^auxil/.*$'
- repo: https://github.com/crate-ci/typos - repo: https://github.com/crate-ci/typos
rev: v1.16.8 rev: v1.16.21
hooks: hooks:
- id: typos - id: typos
exclude: '^(.typos.toml|src/SmithWaterman.cc|testing/.*|auxil/.*|scripts/base/frameworks/files/magic/.*|CHANGES)$' exclude: '^(.typos.toml|src/SmithWaterman.cc|testing/.*|auxil/.*|scripts/base/frameworks/files/magic/.*|CHANGES)$'

View file

@ -67,7 +67,6 @@ uses_seh = "uses_seh"
[default.extend-words] [default.extend-words]
caf = "caf" caf = "caf"
helo = "helo" helo = "helo"
inout = "inout"
# Seems we use this in the management framework # Seems we use this in the management framework
requestor = "requestor" requestor = "requestor"
# `inout` is used as a keyword in Spicy, but looks like a typo of `input`. # `inout` is used as a keyword in Spicy, but looks like a typo of `input`.

View file

@ -15,25 +15,20 @@
#include "zeek/net_util.h" #include "zeek/net_util.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS] = {nullptr}; AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS] = {nullptr};
static uint32_t rand32() static uint32_t rand32() {
{ return ((util::detail::random_number() & 0xffff) << 16) | (util::detail::random_number() & 0xffff);
return ((util::detail::random_number() & 0xffff) << 16) | }
(util::detail::random_number() & 0xffff);
}
// From tcpdpriv. // From tcpdpriv.
static int bi_ffs(uint32_t value) static int bi_ffs(uint32_t value) {
{
int add = 0; int add = 0;
static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}; static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1};
if ( (value & 0xFFFF0000) == 0 ) if ( (value & 0xFFFF0000) == 0 ) {
{
if ( value == 0 ) if ( value == 0 )
// Zero input ==> zero output. // Zero input ==> zero output.
return 0; return 0;
@ -55,54 +50,43 @@ static int bi_ffs(uint32_t value)
value >>= 4; value >>= 4;
return add + bvals[value & 0xf]; return add + bvals[value & 0xf];
} }
#define first_n_bit_mask(n) (~(0xFFFFFFFFU >> n)) #define first_n_bit_mask(n) (~(0xFFFFFFFFU >> n))
ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) {
{
std::map<ipaddr32_t, ipaddr32_t>::iterator p = mapping.find(addr); std::map<ipaddr32_t, ipaddr32_t>::iterator p = mapping.find(addr);
if ( p != mapping.end() ) if ( p != mapping.end() )
return p->second; return p->second;
else else {
{
ipaddr32_t new_addr = anonymize(addr); ipaddr32_t new_addr = anonymize(addr);
mapping[addr] = new_addr; mapping[addr] = new_addr;
return new_addr; return new_addr;
} }
} }
// Keep the specified prefix unchanged. // Keep the specified prefix unchanged.
bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) {
{
reporter->InternalError("prefix preserving is not supported for the anonymizer"); reporter->InternalError("prefix preserving is not supported for the anonymizer");
return false; return false;
} }
bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) {
{ switch ( addr_to_class(ntohl(input)) ) {
switch ( addr_to_class(ntohl(input)) ) case 'A': return PreservePrefix(input, 8);
{ case 'B': return PreservePrefix(input, 16);
case 'A': case 'C': return PreservePrefix(input, 24);
return PreservePrefix(input, 8); default: return false;
case 'B':
return PreservePrefix(input, 16);
case 'C':
return PreservePrefix(input, 24);
default:
return false;
}
} }
}
ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) {
{
++seq; ++seq;
return htonl(seq); return htonl(seq);
} }
ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) {
{
uint8_t digest[16]; uint8_t digest[16];
ipaddr32_t output = 0; ipaddr32_t output = 0;
@ -112,22 +96,20 @@ ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input)
output = (output << 8) | digest[i]; output = (output << 8) | digest[i];
return output; return output;
} }
// This code is from "On the Design and Performance of Prefix-Preserving // This code is from "On the Design and Performance of Prefix-Preserving
// IP Traffic Trace Anonymization", by Xu et al (IMW 2001) // IP Traffic Trace Anonymization", by Xu et al (IMW 2001)
// //
// http://www.imconf.net/imw-2001/proceedings.html // http://www.imconf.net/imw-2001/proceedings.html
ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) {
{
uint8_t digest[16]; uint8_t digest[16];
ipaddr32_t prefix_mask = 0xffffffff; ipaddr32_t prefix_mask = 0xffffffff;
input = ntohl(input); input = ntohl(input);
ipaddr32_t output = input; ipaddr32_t output = input;
for ( int i = 0; i < 32; ++i ) for ( int i = 0; i < 32; ++i ) {
{
// PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 . // PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 .
prefix.len = htonl(i + 1); prefix.len = htonl(i + 1);
prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i))); prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i)));
@ -143,16 +125,14 @@ ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input)
} }
return htonl(output); return htonl(output);
} }
AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() {
{
for ( auto& b : blocks ) for ( auto& b : blocks )
delete[] b; delete[] b;
} }
void AnonymizeIPAddr_A50::init() void AnonymizeIPAddr_A50::init() {
{
root = next_free_node = nullptr; root = next_free_node = nullptr;
// Prepare special nodes for 0.0.0.0 and 255.255.255.255. // Prepare special nodes for 0.0.0.0 and 255.255.255.255.
@ -163,14 +143,12 @@ void AnonymizeIPAddr_A50::init()
method = 0; method = 0;
before_anonymization = 1; before_anonymization = 1;
new_mapping = 0; new_mapping = 0;
} }
bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) {
{
DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits); DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits);
if ( ! before_anonymization ) if ( ! before_anonymization ) {
{
reporter->Error("prefix preservation specified after anonymization begun"); reporter->Error("prefix preservation specified after anonymization begun");
return false; return false;
} }
@ -186,8 +164,7 @@ bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits)
if ( num_bits == 32 ) if ( num_bits == 32 )
n->output = input; n->output = input;
else if ( num_bits > 0 ) else if ( num_bits > 0 ) {
{
assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU); assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU);
uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits); uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits);
uint32_t prefix_mask = ~suffix_mask; uint32_t prefix_mask = ~suffix_mask;
@ -195,24 +172,21 @@ bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits)
} }
return true; return true;
} }
ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) {
{
before_anonymization = 0; before_anonymization = 0;
new_mapping = 0; new_mapping = 0;
if ( Node* n = find_node(ntohl(a)) ) if ( Node* n = find_node(ntohl(a)) ) {
{
ipaddr32_t output = htonl(n->output); ipaddr32_t output = htonl(n->output);
return output; return output;
} }
else else
return 0; return 0;
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() {
{
assert(! next_free_node); assert(! next_free_node);
int block_size = 1024; int block_size = 1024;
@ -229,45 +203,39 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block()
next_free_node = &block[1]; next_free_node = &block[1];
return &block[0]; return &block[0];
} }
inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() {
{
new_mapping = 1; new_mapping = 1;
if ( next_free_node ) if ( next_free_node ) {
{
Node* n = next_free_node; Node* n = next_free_node;
next_free_node = n->child[0]; next_free_node = n->child[0];
return n; return n;
} }
else else
return new_node_block(); return new_node_block();
} }
inline void AnonymizeIPAddr_A50::free_node(Node* n) inline void AnonymizeIPAddr_A50::free_node(Node* n) {
{
n->child[0] = next_free_node; n->child[0] = next_free_node;
next_free_node = n; next_free_node = n;
} }
ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const {
{
// -A50 anonymization // -A50 anonymization
if ( swivel == 32 ) if ( swivel == 32 )
return old_output ^ 1; return old_output ^ 1;
else else {
{
// Bits up to swivel are unchanged; bit swivel is flipped. // Bits up to swivel are unchanged; bit swivel is flipped.
ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel); ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel);
// Remainder of bits are random. // Remainder of bits are random.
return known_part | ((rand32() & 0x7FFFFFFF) >> swivel); return known_part | ((rand32() & 0x7FFFFFFF) >> swivel);
} }
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) {
{
if ( a == 0 || a == 0xFFFFFFFFU ) if ( a == 0 || a == 0xFFFFFFFFU )
reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree"); reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree");
@ -281,8 +249,7 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n)
return nullptr; return nullptr;
down[1] = new_node(); down[1] = new_node();
if ( ! down[1] ) if ( ! down[1] ) {
{
free_node(down[0]); free_node(down[0]);
return nullptr; return nullptr;
} }
@ -305,17 +272,15 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n)
n->child[1] = down[1]; n->child[1] = down[1];
return down[bitvalue]; return down[bitvalue];
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) {
{
// Watch out for special IP addresses, which never make it // Watch out for special IP addresses, which never make it
// into the tree. // into the tree.
if ( a == 0 || a == 0xFFFFFFFFU ) if ( a == 0 || a == 0xFFFFFFFFU )
return &special_nodes[a & 1]; return &special_nodes[a & 1];
if ( ! root ) if ( ! root ) {
{
root = new_node(); root = new_node();
root->input = a; root->input = a;
root->output = rand32(); root->output = rand32();
@ -326,16 +291,14 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a)
// Straight from tcpdpriv. // Straight from tcpdpriv.
Node* n = root; Node* n = root;
while ( n ) while ( n ) {
{
if ( n->input == a ) if ( n->input == a )
return n; return n;
if ( ! n->child[0] ) if ( ! n->child[0] )
n = make_peer(a, n); n = make_peer(a, n);
else else {
{
// swivel is the first bit in which the two children // swivel is the first bit in which the two children
// differ. // differ.
int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input); int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input);
@ -354,14 +317,13 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a)
reporter->InternalError("out of memory!"); reporter->InternalError("out of memory!");
return nullptr; return nullptr;
} }
static TableValPtr anon_preserve_orig_addr; static TableValPtr anon_preserve_orig_addr;
static TableValPtr anon_preserve_resp_addr; static TableValPtr anon_preserve_resp_addr;
static TableValPtr anon_preserve_other_addr; static TableValPtr anon_preserve_other_addr;
void init_ip_addr_anonymizers() void init_ip_addr_anonymizers() {
{
ip_anonymizer[KEEP_ORIG_ADDR] = nullptr; ip_anonymizer[KEEP_ORIG_ADDR] = nullptr;
ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq(); ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq();
ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5(); ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5();
@ -382,17 +344,15 @@ void init_ip_addr_anonymizers()
if ( id ) if ( id )
anon_preserve_other_addr = cast_intrusive<TableVal>(id->GetVal()); anon_preserve_other_addr = cast_intrusive<TableVal>(id->GetVal());
} }
ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) {
{
TableVal* preserve_addr = nullptr; TableVal* preserve_addr = nullptr;
auto addr = make_intrusive<AddrVal>(ip); auto addr = make_intrusive<AddrVal>(ip);
int method = -1; int method = -1;
switch ( cl ) switch ( cl ) {
{
case ORIG_ADDR: // client address case ORIG_ADDR: // client address
preserve_addr = anon_preserve_orig_addr.get(); preserve_addr = anon_preserve_orig_addr.get();
method = orig_addr_anonymization; method = orig_addr_anonymization;
@ -414,8 +374,7 @@ ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl)
if ( preserve_addr && preserve_addr->FindOrDefault(addr) ) if ( preserve_addr && preserve_addr->FindOrDefault(addr) )
new_ip = ip; new_ip = ip;
else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) {
{
if ( method == KEEP_ORIG_ADDR ) if ( method == KEEP_ORIG_ADDR )
new_ip = ip; new_ip = ip;
@ -433,17 +392,15 @@ ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl)
log_anonymization_mapping(ip, new_ip); log_anonymization_mapping(ip, new_ip);
#endif #endif
return new_ip; return new_ip;
} }
#ifdef LOG_ANONYMIZATION_MAPPING #ifdef LOG_ANONYMIZATION_MAPPING
void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) {
{
if ( anonymization_mapping ) if ( anonymization_mapping )
event_mgr.Enqueue(anonymization_mapping, make_intrusive<AddrVal>(input), event_mgr.Enqueue(anonymization_mapping, make_intrusive<AddrVal>(input), make_intrusive<AddrVal>(output));
make_intrusive<AddrVal>(output)); }
}
#endif #endif
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -14,36 +14,32 @@
#include <map> #include <map>
#include <vector> #include <vector>
namespace zeek::detail namespace zeek::detail {
{
// TODO: Anon.h may not be the right place to put these functions ... // TODO: Anon.h may not be the right place to put these functions ...
enum ip_addr_anonymization_class_t enum ip_addr_anonymization_class_t {
{
ORIG_ADDR, // client address ORIG_ADDR, // client address
RESP_ADDR, // server address RESP_ADDR, // server address
OTHER_ADDR, OTHER_ADDR,
NUM_ADDR_ANONYMIZATION_CLASSES, NUM_ADDR_ANONYMIZATION_CLASSES,
}; };
enum ip_addr_anonymization_method_t enum ip_addr_anonymization_method_t {
{
KEEP_ORIG_ADDR, KEEP_ORIG_ADDR,
SEQUENTIALLY_NUMBERED, SEQUENTIALLY_NUMBERED,
RANDOM_MD5, RANDOM_MD5,
PREFIX_PRESERVING_A50, PREFIX_PRESERVING_A50,
PREFIX_PRESERVING_MD5, PREFIX_PRESERVING_MD5,
NUM_ADDR_ANONYMIZATION_METHODS, NUM_ADDR_ANONYMIZATION_METHODS,
}; };
using ipaddr32_t = uint32_t; using ipaddr32_t = uint32_t;
// NOTE: all addresses in parameters of *public* functions are in // NOTE: all addresses in parameters of *public* functions are in
// network order. // network order.
class AnonymizeIPAddr class AnonymizeIPAddr {
{
public: public:
virtual ~AnonymizeIPAddr() = default; virtual ~AnonymizeIPAddr() = default;
@ -57,39 +53,34 @@ public:
protected: protected:
std::map<ipaddr32_t, ipaddr32_t> mapping; std::map<ipaddr32_t, ipaddr32_t> mapping;
}; };
class AnonymizeIPAddr_Seq : public AnonymizeIPAddr class AnonymizeIPAddr_Seq : public AnonymizeIPAddr {
{
public: public:
AnonymizeIPAddr_Seq() { seq = 1; } AnonymizeIPAddr_Seq() { seq = 1; }
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
protected: protected:
ipaddr32_t seq; ipaddr32_t seq;
}; };
class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr {
{
public: public:
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
}; };
class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr {
{
public: public:
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
protected: protected:
struct anon_prefix struct anon_prefix {
{
int len; int len;
ipaddr32_t prefix; ipaddr32_t prefix;
} prefix; } prefix;
}; };
class AnonymizeIPAddr_A50 : public AnonymizeIPAddr class AnonymizeIPAddr_A50 : public AnonymizeIPAddr {
{
public: public:
AnonymizeIPAddr_A50() { init(); } AnonymizeIPAddr_A50() { init(); }
~AnonymizeIPAddr_A50() override; ~AnonymizeIPAddr_A50() override;
@ -98,8 +89,7 @@ public:
bool PreservePrefix(ipaddr32_t input, int num_bits) override; bool PreservePrefix(ipaddr32_t input, int num_bits) override;
protected: protected:
struct Node struct Node {
{
ipaddr32_t input; ipaddr32_t input;
ipaddr32_t output; ipaddr32_t output;
Node* child[2]; Node* child[2];
@ -128,7 +118,7 @@ protected:
ipaddr32_t make_output(ipaddr32_t, int) const; ipaddr32_t make_output(ipaddr32_t, int) const;
Node* make_peer(ipaddr32_t, Node*); Node* make_peer(ipaddr32_t, Node*);
Node* find_node(ipaddr32_t); Node* find_node(ipaddr32_t);
}; };
// The global IP anonymizers. // The global IP anonymizers.
extern AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS]; extern AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS];
@ -139,4 +129,4 @@ ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl);
#define LOG_ANONYMIZATION_MAPPING #define LOG_ANONYMIZATION_MAPPING
void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output); void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output);
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -11,11 +11,9 @@
#include "zeek/input/Manager.h" #include "zeek/input/Manager.h"
#include "zeek/threading/SerialTypes.h" #include "zeek/threading/SerialTypes.h"
namespace zeek::detail namespace zeek::detail {
{
const char* attr_name(AttrTag t) const char* attr_name(AttrTag t) {
{
// Do not collapse the list. // Do not collapse the list.
// clang-format off // clang-format off
static const char* attr_names[int(NUM_ATTRS)] = { static const char* attr_names[int(NUM_ATTRS)] = {
@ -48,23 +46,18 @@ const char* attr_name(AttrTag t)
// clang-format on // clang-format on
return attr_names[int(t)]; return attr_names[int(t)];
} }
Attr::Attr(AttrTag t, ExprPtr e) : expr(std::move(e)) Attr::Attr(AttrTag t, ExprPtr e) : expr(std::move(e)) {
{
tag = t; tag = t;
SetLocationInfo(&start_location, &end_location); SetLocationInfo(&start_location, &end_location);
} }
Attr::Attr(AttrTag t) : Attr(t, nullptr) { } Attr::Attr(AttrTag t) : Attr(t, nullptr) {}
void Attr::SetAttrExpr(ExprPtr e) void Attr::SetAttrExpr(ExprPtr e) { expr = std::move(e); }
{
expr = std::move(e);
}
std::string Attr::DeprecationMessage() const std::string Attr::DeprecationMessage() const {
{
if ( tag != ATTR_DEPRECATED ) if ( tag != ATTR_DEPRECATED )
return ""; return "";
@ -73,42 +66,35 @@ std::string Attr::DeprecationMessage() const
auto ce = static_cast<ConstExpr*>(expr.get()); auto ce = static_cast<ConstExpr*>(expr.get());
return ce->Value()->AsStringVal()->CheckString(); return ce->Value()->AsStringVal()->CheckString();
} }
void Attr::Describe(ODesc* d) const void Attr::Describe(ODesc* d) const {
{
AddTag(d); AddTag(d);
if ( expr ) if ( expr ) {
{
if ( ! d->IsBinary() ) if ( ! d->IsBinary() )
d->Add("="); d->Add("=");
expr->Describe(d); expr->Describe(d);
} }
} }
void Attr::DescribeReST(ODesc* d, bool shorten) const void Attr::DescribeReST(ODesc* d, bool shorten) const {
{ auto add_long_expr_string = [](ODesc* d, const std::string& s, bool shorten) {
auto add_long_expr_string = [](ODesc* d, const std::string& s, bool shorten)
{
constexpr auto max_expr_chars = 32; constexpr auto max_expr_chars = 32;
constexpr auto shortened_expr = "*...*"; constexpr auto shortened_expr = "*...*";
if ( s.size() > max_expr_chars ) if ( s.size() > max_expr_chars ) {
{
if ( shorten ) if ( shorten )
d->Add(shortened_expr); d->Add(shortened_expr);
else else {
{
// Long inline-literals likely won't wrap well in HTML render // Long inline-literals likely won't wrap well in HTML render
d->Add("*"); d->Add("*");
d->Add(s); d->Add(s);
d->Add("*"); d->Add("*");
} }
} }
else else {
{
d->Add("``"); d->Add("``");
d->Add(s); d->Add(s);
d->Add("``"); d->Add("``");
@ -119,28 +105,24 @@ void Attr::DescribeReST(ODesc* d, bool shorten) const
AddTag(d); AddTag(d);
d->Add("`"); d->Add("`");
if ( expr ) if ( expr ) {
{
d->SP(); d->SP();
d->Add("="); d->Add("=");
d->SP(); d->SP();
if ( expr->Tag() == EXPR_NAME ) if ( expr->Tag() == EXPR_NAME ) {
{
d->Add(":zeek:see:`"); d->Add(":zeek:see:`");
expr->Describe(d); expr->Describe(d);
d->Add("`"); d->Add("`");
} }
else if ( expr->GetType()->Tag() == TYPE_FUNC ) else if ( expr->GetType()->Tag() == TYPE_FUNC ) {
{
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
d->Add(expr->GetType()->AsFuncType()->FlavorString()); d->Add(expr->GetType()->AsFuncType()->FlavorString());
d->Add("`"); d->Add("`");
} }
else if ( expr->Tag() == EXPR_CONST ) else if ( expr->Tag() == EXPR_CONST ) {
{
ODesc dd; ODesc dd;
dd.SetQuotes(true); dd.SetQuotes(true);
expr->Describe(&dd); expr->Describe(&dd);
@ -148,8 +130,7 @@ void Attr::DescribeReST(ODesc* d, bool shorten) const
add_long_expr_string(d, s, shorten); add_long_expr_string(d, s, shorten);
} }
else else {
{
ODesc dd; ODesc dd;
expr->Eval(nullptr)->Describe(&dd); expr->Eval(nullptr)->Describe(&dd);
std::string s = dd.Description(); std::string s = dd.Description();
@ -161,39 +142,32 @@ void Attr::DescribeReST(ODesc* d, bool shorten) const
add_long_expr_string(d, s, shorten); add_long_expr_string(d, s, shorten);
} }
} }
} }
void Attr::AddTag(ODesc* d) const void Attr::AddTag(ODesc* d) const {
{
if ( d->IsBinary() ) if ( d->IsBinary() )
d->Add(static_cast<zeek_int_t>(Tag())); d->Add(static_cast<zeek_int_t>(Tag()));
else else
d->Add(attr_name(Tag())); d->Add(attr_name(Tag()));
} }
detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const {
{
auto tc = cb->PreAttr(this); auto tc = cb->PreAttr(this);
HANDLE_TC_ATTR_PRE(tc); HANDLE_TC_ATTR_PRE(tc);
if ( expr ) if ( expr ) {
{
auto tc = expr->Traverse(cb); auto tc = expr->Traverse(cb);
HANDLE_TC_ATTR_PRE(tc); HANDLE_TC_ATTR_PRE(tc);
} }
tc = cb->PostAttr(this); tc = cb->PostAttr(this);
HANDLE_TC_ATTR_POST(tc); HANDLE_TC_ATTR_POST(tc);
} }
Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global) Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global)
: Attributes(std::vector<AttrPtr>{}, std::move(t), arg_in_record, is_global) : Attributes(std::vector<AttrPtr>{}, std::move(t), arg_in_record, is_global) {}
{
}
Attributes::Attributes(std::vector<AttrPtr> a, TypePtr t, bool arg_in_record, bool is_global) Attributes::Attributes(std::vector<AttrPtr> a, TypePtr t, bool arg_in_record, bool is_global) : type(std::move(t)) {
: type(std::move(t))
{
attrs.reserve(a.size()); attrs.reserve(a.size());
in_record = arg_in_record; in_record = arg_in_record;
global_var = is_global; global_var = is_global;
@ -206,21 +180,17 @@ Attributes::Attributes(std::vector<AttrPtr> a, TypePtr t, bool arg_in_record, bo
for ( auto& attr : a ) for ( auto& attr : a )
AddAttr(std::move(attr)); AddAttr(std::move(attr));
} }
void Attributes::AddAttr(AttrPtr attr, bool is_redef) void Attributes::AddAttr(AttrPtr attr, bool is_redef) {
{ auto acceptable_duplicate_attr = [](const AttrPtr& attr, const AttrPtr& existing) -> bool {
auto acceptable_duplicate_attr = [](const AttrPtr& attr, const AttrPtr& existing) -> bool
{
if ( attr == existing ) if ( attr == existing )
return true; return true;
AttrTag new_tag = attr->Tag(); AttrTag new_tag = attr->Tag();
if ( new_tag == ATTR_DEPRECATED ) if ( new_tag == ATTR_DEPRECATED ) {
{ if ( ! attr->DeprecationMessage().empty() || (existing && ! existing->DeprecationMessage().empty()) )
if ( ! attr->DeprecationMessage().empty() ||
(existing && ! existing->DeprecationMessage().empty()) )
return false; return false;
return true; return true;
@ -233,8 +203,7 @@ void Attributes::AddAttr(AttrPtr attr, bool is_redef)
// A `redef` is allowed to overwrite an existing attribute instead of // A `redef` is allowed to overwrite an existing attribute instead of
// flagging it as ambiguous. // flagging it as ambiguous.
if ( ! is_redef ) if ( ! is_redef ) {
{
auto existing = Find(attr->Tag()); auto existing = Find(attr->Tag());
if ( existing && ! acceptable_duplicate_attr(attr, existing) ) if ( existing && ! acceptable_duplicate_attr(attr, existing) )
reporter->Error("Duplicate %s attribute is ambiguous", attr_name(attr->Tag())); reporter->Error("Duplicate %s attribute is ambiguous", attr_name(attr->Tag()));
@ -254,85 +223,71 @@ void Attributes::AddAttr(AttrPtr attr, bool is_redef)
// For ADD_FUNC or DEL_FUNC, add in an implicit REDEF, since // For ADD_FUNC or DEL_FUNC, add in an implicit REDEF, since
// those attributes only have meaning for a redefinable value. // those attributes only have meaning for a redefinable value.
if ( (attr->Tag() == ATTR_ADD_FUNC || attr->Tag() == ATTR_DEL_FUNC) && ! Find(ATTR_REDEF) ) if ( (attr->Tag() == ATTR_ADD_FUNC || attr->Tag() == ATTR_DEL_FUNC) && ! Find(ATTR_REDEF) ) {
{
auto a = make_intrusive<Attr>(ATTR_REDEF); auto a = make_intrusive<Attr>(ATTR_REDEF);
attrs.emplace_back(std::move(a)); attrs.emplace_back(std::move(a));
} }
// For DEFAULT, add an implicit OPTIONAL if it's not a global. // For DEFAULT, add an implicit OPTIONAL if it's not a global.
if ( ! global_var && attr->Tag() == ATTR_DEFAULT && ! Find(ATTR_OPTIONAL) ) if ( ! global_var && attr->Tag() == ATTR_DEFAULT && ! Find(ATTR_OPTIONAL) ) {
{
auto a = make_intrusive<Attr>(ATTR_OPTIONAL); auto a = make_intrusive<Attr>(ATTR_OPTIONAL);
attrs.emplace_back(std::move(a)); attrs.emplace_back(std::move(a));
} }
} }
void Attributes::AddAttrs(const AttributesPtr& a, bool is_redef) void Attributes::AddAttrs(const AttributesPtr& a, bool is_redef) {
{
for ( const auto& attr : a->GetAttrs() ) for ( const auto& attr : a->GetAttrs() )
AddAttr(attr, is_redef); AddAttr(attr, is_redef);
} }
const AttrPtr& Attributes::Find(AttrTag t) const const AttrPtr& Attributes::Find(AttrTag t) const {
{
for ( const auto& a : attrs ) for ( const auto& a : attrs )
if ( a->Tag() == t ) if ( a->Tag() == t )
return a; return a;
return Attr::nil; return Attr::nil;
} }
void Attributes::RemoveAttr(AttrTag t) void Attributes::RemoveAttr(AttrTag t) {
{ for ( auto it = attrs.begin(); it != attrs.end(); ) {
for ( auto it = attrs.begin(); it != attrs.end(); )
{
if ( (*it)->Tag() == t ) if ( (*it)->Tag() == t )
it = attrs.erase(it); it = attrs.erase(it);
else else
++it; ++it;
} }
} }
void Attributes::Describe(ODesc* d) const void Attributes::Describe(ODesc* d) const {
{ if ( attrs.empty() ) {
if ( attrs.empty() )
{
d->AddCount(0); d->AddCount(0);
return; return;
} }
d->AddCount(static_cast<uint64_t>(attrs.size())); d->AddCount(static_cast<uint64_t>(attrs.size()));
for ( size_t i = 0; i < attrs.size(); ++i ) for ( size_t i = 0; i < attrs.size(); ++i ) {
{
if ( d->IsReadable() && i > 0 ) if ( d->IsReadable() && i > 0 )
d->Add(", "); d->Add(", ");
attrs[i]->Describe(d); attrs[i]->Describe(d);
} }
} }
void Attributes::DescribeReST(ODesc* d, bool shorten) const void Attributes::DescribeReST(ODesc* d, bool shorten) const {
{ for ( size_t i = 0; i < attrs.size(); ++i ) {
for ( size_t i = 0; i < attrs.size(); ++i )
{
if ( i > 0 ) if ( i > 0 )
d->Add(" "); d->Add(" ");
attrs[i]->DescribeReST(d, shorten); attrs[i]->DescribeReST(d, shorten);
} }
} }
void Attributes::CheckAttr(Attr* a) void Attributes::CheckAttr(Attr* a) {
{ switch ( a->Tag() ) {
switch ( a->Tag() )
{
case ATTR_DEPRECATED: case ATTR_DEPRECATED:
case ATTR_REDEF: case ATTR_REDEF:
case ATTR_IS_ASSIGNED: case ATTR_IS_ASSIGNED:
case ATTR_IS_USED: case ATTR_IS_USED: break;
break;
case ATTR_OPTIONAL: case ATTR_OPTIONAL:
if ( global_var ) if ( global_var )
@ -340,67 +295,53 @@ void Attributes::CheckAttr(Attr* a)
break; break;
case ATTR_ADD_FUNC: case ATTR_ADD_FUNC:
case ATTR_DEL_FUNC: case ATTR_DEL_FUNC: {
{
bool is_add = a->Tag() == ATTR_ADD_FUNC; bool is_add = a->Tag() == ATTR_ADD_FUNC;
const auto& at = a->GetExpr()->GetType(); const auto& at = a->GetExpr()->GetType();
if ( at->Tag() != TYPE_FUNC ) if ( at->Tag() != TYPE_FUNC ) {
{ a->GetExpr()->Error(is_add ? "&add_func must be a function" : "&delete_func must be a function");
a->GetExpr()->Error(is_add ? "&add_func must be a function"
: "&delete_func must be a function");
break; break;
} }
FuncType* aft = at->AsFuncType(); FuncType* aft = at->AsFuncType();
if ( ! same_type(aft->Yield(), type) ) if ( ! same_type(aft->Yield(), type) ) {
{ a->GetExpr()->Error(is_add ? "&add_func function must yield same type as variable" :
a->GetExpr()->Error(is_add "&delete_func function must yield same type as variable");
? "&add_func function must yield same type as variable"
: "&delete_func function must yield same type as variable");
break; break;
} }
} } break;
break;
case ATTR_DEFAULT_INSERT: case ATTR_DEFAULT_INSERT: {
{ if ( ! type->IsTable() ) {
if ( ! type->IsTable() )
{
Error("&default_insert only applicable to tables"); Error("&default_insert only applicable to tables");
break; break;
} }
if ( Find(ATTR_DEFAULT) ) if ( Find(ATTR_DEFAULT) ) {
{
Error("&default and &default_insert cannot be used together"); Error("&default and &default_insert cannot be used together");
break; break;
} }
std::string err_msg; std::string err_msg;
if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && ! err_msg.empty() )
! err_msg.empty() )
Error(err_msg.c_str()); Error(err_msg.c_str());
break; break;
} }
case ATTR_DEFAULT: case ATTR_DEFAULT: {
{ if ( Find(ATTR_DEFAULT_INSERT) ) {
if ( Find(ATTR_DEFAULT_INSERT) )
{
Error("&default and &default_insert cannot be used together"); Error("&default and &default_insert cannot be used together");
break; break;
} }
std::string err_msg; std::string err_msg;
if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && ! err_msg.empty() )
! err_msg.empty() )
Error(err_msg.c_str()); Error(err_msg.c_str());
break; break;
} }
case ATTR_EXPIRE_READ: case ATTR_EXPIRE_READ: {
{
if ( Find(ATTR_BROKER_STORE) ) if ( Find(ATTR_BROKER_STORE) )
Error("&broker_store and &read_expire cannot be used simultaneously"); Error("&broker_store and &read_expire cannot be used simultaneously");
@ -410,26 +351,23 @@ void Attributes::CheckAttr(Attr* a)
// fallthrough // fallthrough
case ATTR_EXPIRE_WRITE: case ATTR_EXPIRE_WRITE:
case ATTR_EXPIRE_CREATE: case ATTR_EXPIRE_CREATE: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("expiration only applicable to sets/tables"); Error("expiration only applicable to sets/tables");
break; break;
} }
int num_expires = 0; int num_expires = 0;
for ( const auto& at : attrs ) for ( const auto& at : attrs ) {
{
if ( at->Tag() == ATTR_EXPIRE_READ || at->Tag() == ATTR_EXPIRE_WRITE || if ( at->Tag() == ATTR_EXPIRE_READ || at->Tag() == ATTR_EXPIRE_WRITE ||
at->Tag() == ATTR_EXPIRE_CREATE ) at->Tag() == ATTR_EXPIRE_CREATE )
num_expires++; num_expires++;
} }
if ( num_expires > 1 ) if ( num_expires > 1 ) {
{ Error(
Error("set/table can only have one of &read_expire, &write_expire, " "set/table can only have one of &read_expire, &write_expire, "
"&create_expire"); "&create_expire");
break; break;
} }
@ -443,10 +381,8 @@ void Attributes::CheckAttr(Attr* a)
break; break;
case ATTR_EXPIRE_FUNC: case ATTR_EXPIRE_FUNC: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("expiration only applicable to tables"); Error("expiration only applicable to tables");
break; break;
} }
@ -462,10 +398,8 @@ void Attributes::CheckAttr(Attr* a)
break; break;
} }
case ATTR_ON_CHANGE: case ATTR_ON_CHANGE: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("&on_change only applicable to sets/tables"); Error("&on_change only applicable to sets/tables");
break; break;
} }
@ -478,8 +412,7 @@ void Attributes::CheckAttr(Attr* a)
const FuncType* c_ft = change_func->GetType()->AsFuncType(); const FuncType* c_ft = change_func->GetType()->AsFuncType();
if ( c_ft->Yield()->Tag() != TYPE_VOID ) if ( c_ft->Yield()->Tag() != TYPE_VOID ) {
{
Error("&on_change must not return a value"); Error("&on_change must not return a value");
break; break;
} }
@ -491,47 +424,38 @@ void Attributes::CheckAttr(Attr* a)
const auto& args = c_ft->ParamList()->GetTypes(); const auto& args = c_ft->ParamList()->GetTypes();
const auto& t_indexes = the_table->GetIndexTypes(); const auto& t_indexes = the_table->GetIndexTypes();
if ( args.size() != (type->IsSet() ? 2 : 3) + t_indexes.size() ) if ( args.size() != (type->IsSet() ? 2 : 3) + t_indexes.size() ) {
{
Error("&on_change function has incorrect number of arguments"); Error("&on_change function has incorrect number of arguments");
break; break;
} }
if ( ! same_type(args[0], the_table->AsTableType()) ) if ( ! same_type(args[0], the_table->AsTableType()) ) {
{
Error("&on_change: first argument must be of same type as table"); Error("&on_change: first argument must be of same type as table");
break; break;
} }
// can't check exact type here yet - the data structures don't exist yet. // can't check exact type here yet - the data structures don't exist yet.
if ( args[1]->Tag() != TYPE_ENUM ) if ( args[1]->Tag() != TYPE_ENUM ) {
{
Error("&on_change: second argument must be a TableChange enum"); Error("&on_change: second argument must be a TableChange enum");
break; break;
} }
for ( size_t i = 0; i < t_indexes.size(); i++ ) for ( size_t i = 0; i < t_indexes.size(); i++ ) {
{ if ( ! same_type(args[2 + i], t_indexes[i]) ) {
if ( ! same_type(args[2 + i], t_indexes[i]) )
{
Error("&on_change: index types do not match table"); Error("&on_change: index types do not match table");
break; break;
} }
} }
if ( ! type->IsSet() ) if ( ! type->IsSet() )
if ( ! same_type(args[2 + t_indexes.size()], the_table->Yield()) ) if ( ! same_type(args[2 + t_indexes.size()], the_table->Yield()) ) {
{
Error("&on_change: value type does not match table"); Error("&on_change: value type does not match table");
break; break;
} }
} } break;
break;
case ATTR_BACKEND: case ATTR_BACKEND: {
{ if ( ! global_var || type->Tag() != TYPE_TABLE ) {
if ( ! global_var || type->Tag() != TYPE_TABLE )
{
Error("&backend only applicable to global sets/tables"); Error("&backend only applicable to global sets/tables");
break; break;
} }
@ -539,8 +463,7 @@ void Attributes::CheckAttr(Attr* a)
// cannot do better equality check - the Broker types are not // cannot do better equality check - the Broker types are not
// actually existing yet when we are here. We will do that // actually existing yet when we are here. We will do that
// later - before actually attaching to a broker store // later - before actually attaching to a broker store
if ( a->GetExpr()->GetType()->Tag() != TYPE_ENUM ) if ( a->GetExpr()->GetType()->Tag() != TYPE_ENUM ) {
{
Error("&backend must take an enum argument"); Error("&backend must take an enum argument");
break; break;
} }
@ -549,8 +472,7 @@ void Attributes::CheckAttr(Attr* a)
// explicitly overridden // explicitly overridden
if ( ! type->AsTableType()->IsSet() && if ( ! type->AsTableType()->IsSet() &&
! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) &&
! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) {
{
Error("&backend only supports atomic types as table value"); Error("&backend only supports atomic types as table value");
} }
@ -566,16 +488,13 @@ void Attributes::CheckAttr(Attr* a)
break; break;
} }
case ATTR_BROKER_STORE: case ATTR_BROKER_STORE: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("&broker_store only applicable to sets/tables"); Error("&broker_store only applicable to sets/tables");
break; break;
} }
if ( a->GetExpr()->GetType()->Tag() != TYPE_STRING ) if ( a->GetExpr()->GetType()->Tag() != TYPE_STRING ) {
{
Error("&broker_store must take a string argument"); Error("&broker_store must take a string argument");
break; break;
} }
@ -584,8 +503,7 @@ void Attributes::CheckAttr(Attr* a)
// explicitly overridden // explicitly overridden
if ( ! type->AsTableType()->IsSet() && if ( ! type->AsTableType()->IsSet() &&
! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) &&
! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) {
{
Error("&broker_store only supports atomic types as table value"); Error("&broker_store only supports atomic types as table value");
} }
@ -601,10 +519,8 @@ void Attributes::CheckAttr(Attr* a)
break; break;
} }
case ATTR_BROKER_STORE_ALLOW_COMPLEX: case ATTR_BROKER_STORE_ALLOW_COMPLEX: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("&broker_allow_complex_type only applicable to sets/tables"); Error("&broker_allow_complex_type only applicable to sets/tables");
break; break;
} }
@ -619,9 +535,7 @@ void Attributes::CheckAttr(Attr* a)
Error("&raw_output only applicable to files"); Error("&raw_output only applicable to files");
break; break;
case ATTR_PRIORITY: case ATTR_PRIORITY: Error("&priority only applicable to event bodies"); break;
Error("&priority only applicable to event bodies");
break;
case ATTR_GROUP: case ATTR_GROUP:
if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT ) if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT )
@ -638,18 +552,15 @@ void Attributes::CheckAttr(Attr* a)
Error("&log applied to a type that cannot be logged"); Error("&log applied to a type that cannot be logged");
break; break;
case ATTR_TYPE_COLUMN: case ATTR_TYPE_COLUMN: {
{ if ( type->Tag() != TYPE_PORT ) {
if ( type->Tag() != TYPE_PORT )
{
Error("type_column tag only applicable to ports"); Error("type_column tag only applicable to ports");
break; break;
} }
const auto& atype = a->GetExpr()->GetType(); const auto& atype = a->GetExpr()->GetType();
if ( atype->Tag() != TYPE_STRING ) if ( atype->Tag() != TYPE_STRING ) {
{
Error("type column needs to have a string argument"); Error("type column needs to have a string argument");
break; break;
} }
@ -662,21 +573,18 @@ void Attributes::CheckAttr(Attr* a)
Error("&ordered only applicable to tables"); Error("&ordered only applicable to tables");
break; break;
default: default: BadTag("Attributes::CheckAttr", attr_name(a->Tag()));
BadTag("Attributes::CheckAttr", attr_name(a->Tag()));
}
} }
}
bool Attributes::operator==(const Attributes& other) const bool Attributes::operator==(const Attributes& other) const {
{
if ( attrs.empty() ) if ( attrs.empty() )
return other.attrs.empty(); return other.attrs.empty();
if ( other.attrs.empty() ) if ( other.attrs.empty() )
return false; return false;
for ( const auto& a : attrs ) for ( const auto& a : attrs ) {
{
const auto& o = other.Find(a->Tag()); const auto& o = other.Find(a->Tag());
if ( ! o ) if ( ! o )
@ -686,8 +594,7 @@ bool Attributes::operator==(const Attributes& other) const
return false; return false;
} }
for ( const auto& o : other.attrs ) for ( const auto& o : other.attrs ) {
{
const auto& a = Find(o->Tag()); const auto& a = Find(o->Tag());
if ( ! a ) if ( ! a )
@ -698,25 +605,21 @@ bool Attributes::operator==(const Attributes& other) const
} }
return true; return true;
} }
bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, std::string& err_msg) {
std::string& err_msg)
{
ASSERT(a->Tag() == ATTR_DEFAULT || a->Tag() == ATTR_DEFAULT_INSERT); ASSERT(a->Tag() == ATTR_DEFAULT || a->Tag() == ATTR_DEFAULT_INSERT);
std::string aname = attr_name(a->Tag()); std::string aname = attr_name(a->Tag());
// &default is allowed for global tables, since it's used in // &default is allowed for global tables, since it's used in
// initialization of table fields. It's not allowed otherwise. // initialization of table fields. It's not allowed otherwise.
if ( global_var && ! type->IsTable() ) if ( global_var && ! type->IsTable() ) {
{
err_msg = aname + " is not valid for global variables except for tables"; err_msg = aname + " is not valid for global variables except for tables";
return false; return false;
} }
const auto& atype = a->GetExpr()->GetType(); const auto& atype = a->GetExpr()->GetType();
if ( type->Tag() != TYPE_TABLE || (type->IsSet() && ! in_record) ) if ( type->Tag() != TYPE_TABLE || (type->IsSet() && ! in_record) ) {
{
if ( same_type(atype, type) ) if ( same_type(atype, type) )
// Ok. // Ok.
return true; return true;
@ -733,8 +636,7 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
auto e = check_and_promote_expr(a->GetExpr(), type); auto e = check_and_promote_expr(a->GetExpr(), type);
if ( e ) if ( e ) {
{
a->SetAttrExpr(std::move(e)); a->SetAttrExpr(std::move(e));
// Ok. // Ok.
return true; return true;
@ -747,17 +649,14 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
TableType* tt = type->AsTableType(); TableType* tt = type->AsTableType();
const auto& ytype = tt->Yield(); const auto& ytype = tt->Yield();
if ( ! in_record ) if ( ! in_record ) { // &default applies to the type itself.
{ // &default applies to the type itself.
if ( same_type(atype, ytype) ) if ( same_type(atype, ytype) )
return true; return true;
// It can still be a default function. // It can still be a default function.
if ( atype->Tag() == TYPE_FUNC ) if ( atype->Tag() == TYPE_FUNC ) {
{
FuncType* f = atype->AsFuncType(); FuncType* f = atype->AsFuncType();
if ( ! f->CheckArgs(tt->GetIndexTypes()) || ! same_type(f->Yield(), ytype) ) if ( ! f->CheckArgs(tt->GetIndexTypes()) || ! same_type(f->Yield(), ytype) ) {
{
err_msg = aname + " function type clash"; err_msg = aname + " function type clash";
return false; return false;
} }
@ -774,8 +673,7 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
auto e = check_and_promote_expr(a->GetExpr(), ytype); auto e = check_and_promote_expr(a->GetExpr(), ytype);
if ( e ) if ( e ) {
{
a->SetAttrExpr(std::move(e)); a->SetAttrExpr(std::move(e));
// Ok. // Ok.
return true; return true;
@ -790,12 +688,10 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
if ( same_type(atype, type) ) if ( same_type(atype, type) )
return true; return true;
if ( (atype->Tag() == TYPE_TABLE && atype->AsTableType()->IsUnspecifiedTable()) ) if ( (atype->Tag() == TYPE_TABLE && atype->AsTableType()->IsUnspecifiedTable()) ) {
{
auto e = check_and_promote_expr(a->GetExpr(), type); auto e = check_and_promote_expr(a->GetExpr(), type);
if ( e ) if ( e ) {
{
a->SetAttrExpr(std::move(e)); a->SetAttrExpr(std::move(e));
return true; return true;
} }
@ -809,21 +705,19 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
err_msg = "&default value has inconsistent type"; err_msg = "&default value has inconsistent type";
return false; return false;
} }
detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const {
{
auto tc = cb->PreAttrs(this); auto tc = cb->PreAttrs(this);
HANDLE_TC_ATTRS_PRE(tc); HANDLE_TC_ATTRS_PRE(tc);
for ( const auto& a : attrs ) for ( const auto& a : attrs ) {
{
tc = a->Traverse(cb); tc = a->Traverse(cb);
HANDLE_TC_ATTRS_PRE(tc); HANDLE_TC_ATTRS_PRE(tc);
} }
tc = cb->PostAttrs(this); tc = cb->PostAttrs(this);
HANDLE_TC_ATTRS_POST(tc); HANDLE_TC_ATTRS_POST(tc);
} }
} } // namespace zeek::detail

View file

@ -14,20 +14,17 @@
// modify expressions or supply metadata on types, and the kind that // modify expressions or supply metadata on types, and the kind that
// are extra metadata on every variable instance. // are extra metadata on every variable instance.
namespace zeek namespace zeek {
{
class Type; class Type;
using TypePtr = IntrusivePtr<Type>; using TypePtr = IntrusivePtr<Type>;
namespace detail namespace detail {
{
class Expr; class Expr;
using ExprPtr = IntrusivePtr<Expr>; using ExprPtr = IntrusivePtr<Expr>;
enum AttrTag enum AttrTag {
{
ATTR_OPTIONAL, ATTR_OPTIONAL,
ATTR_DEFAULT, ATTR_DEFAULT,
ATTR_DEFAULT_INSERT, // insert default value on failed lookups ATTR_DEFAULT_INSERT, // insert default value on failed lookups
@ -54,15 +51,14 @@ enum AttrTag
ATTR_IS_USED, // to suppress usage warnings ATTR_IS_USED, // to suppress usage warnings
ATTR_ORDERED, // used to store tables in ordered mode ATTR_ORDERED, // used to store tables in ordered mode
NUM_ATTRS // this item should always be last NUM_ATTRS // this item should always be last
}; };
class Attr; class Attr;
using AttrPtr = IntrusivePtr<Attr>; using AttrPtr = IntrusivePtr<Attr>;
class Attributes; class Attributes;
using AttributesPtr = IntrusivePtr<Attributes>; using AttributesPtr = IntrusivePtr<Attributes>;
class Attr final : public Obj class Attr final : public Obj {
{
public: public:
static inline const AttrPtr nil; static inline const AttrPtr nil;
@ -86,8 +82,7 @@ public:
*/ */
std::string DeprecationMessage() const; std::string DeprecationMessage() const;
bool operator==(const Attr& other) const bool operator==(const Attr& other) const {
{
if ( tag != other.tag ) if ( tag != other.tag )
return false; return false;
@ -108,11 +103,10 @@ protected:
AttrTag tag; AttrTag tag;
ExprPtr expr; ExprPtr expr;
}; };
// Manages a collection of attributes. // Manages a collection of attributes.
class Attributes final : public Obj class Attributes final : public Obj {
{
public: public:
Attributes(std::vector<AttrPtr> a, TypePtr t, bool in_record, bool is_global); Attributes(std::vector<AttrPtr> a, TypePtr t, bool in_record, bool is_global);
Attributes(TypePtr t, bool in_record, bool is_global); Attributes(TypePtr t, bool in_record, bool is_global);
@ -144,7 +138,7 @@ protected:
bool in_record; bool in_record;
bool global_var; bool global_var;
}; };
// Checks whether default attribute "a" is compatible with the given type. // Checks whether default attribute "a" is compatible with the given type.
// "global_var" specifies whether the attribute is being associated with // "global_var" specifies whether the attribute is being associated with
@ -154,8 +148,7 @@ protected:
// Returns true on compatibility (which might include modifying "a"), false // Returns true on compatibility (which might include modifying "a"), false
// on an error. If an error message hasn't been directly generated, then // on an error. If an error message hasn't been directly generated, then
// it will be returned in err_msg. // it will be returned in err_msg.
extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, std::string& err_msg);
std::string& err_msg);
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -8,15 +8,13 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
int Base64Converter::default_base64_table[256]; int Base64Converter::default_base64_table[256];
const std::string Base64Converter::default_alphabet = const std::string Base64Converter::default_alphabet =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) {
{
int blen; int blen;
char* buf; char* buf;
@ -26,20 +24,17 @@ void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, cha
if ( *pbuf && (*pblen % 4 != 0) ) if ( *pbuf && (*pblen % 4 != 0) )
reporter->InternalError("Base64 encode buffer not a multiple of 4"); reporter->InternalError("Base64 encode buffer not a multiple of 4");
if ( *pbuf ) if ( *pbuf ) {
{
buf = *pbuf; buf = *pbuf;
blen = *pblen; blen = *pblen;
} }
else else {
{
blen = (int)(4 * ceil((double)len / 3)); blen = (int)(4 * ceil((double)len / 3));
*pbuf = buf = new char[blen]; *pbuf = buf = new char[blen];
*pblen = blen; *pblen = blen;
} }
for ( int i = 0, j = 0; (i < len) && (j < blen); ) for ( int i = 0, j = 0; (i < len) && (j < blen); ) {
{
uint32_t bit32 = data[i++] << 16; uint32_t bit32 = data[i++] << 16;
bit32 += (i++ < len ? data[i - 1] : 0) << 8; bit32 += (i++ < len ? data[i - 1] : 0) << 8;
bit32 += i++ < len ? data[i - 1] : 0; bit32 += i++ < len ? data[i - 1] : 0;
@ -49,10 +44,9 @@ void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, cha
buf[j++] = (i == (len + 2)) ? '=' : alphabet[(bit32 >> 6) & 0x3f]; buf[j++] = (i == (len + 2)) ? '=' : alphabet[(bit32 >> 6) & 0x3f];
buf[j++] = (i >= (len + 1)) ? '=' : alphabet[bit32 & 0x3f]; buf[j++] = (i >= (len + 1)) ? '=' : alphabet[bit32 & 0x3f];
} }
} }
int* Base64Converter::InitBase64Table(const std::string& alphabet) int* Base64Converter::InitBase64Table(const std::string& alphabet) {
{
assert(alphabet.size() == 64); assert(alphabet.size() == 64);
static bool default_table_initialized = false; static bool default_table_initialized = false;
@ -62,8 +56,7 @@ int* Base64Converter::InitBase64Table(const std::string& alphabet)
int* base64_table = nullptr; int* base64_table = nullptr;
if ( alphabet == default_alphabet ) if ( alphabet == default_alphabet ) {
{
base64_table = default_base64_table; base64_table = default_base64_table;
default_table_initialized = true; default_table_initialized = true;
} }
@ -74,8 +67,7 @@ int* Base64Converter::InitBase64Table(const std::string& alphabet)
for ( i = 0; i < 256; ++i ) for ( i = 0; i < 256; ++i )
base64_table[i] = -1; base64_table[i] = -1;
for ( i = 0; i < 26; ++i ) for ( i = 0; i < 26; ++i ) {
{
base64_table[int(alphabet[0 + i])] = i; base64_table[int(alphabet[0 + i])] = i;
base64_table[int(alphabet[26 + i])] = i + 26; base64_table[int(alphabet[26 + i])] = i + 26;
} }
@ -89,17 +81,14 @@ int* Base64Converter::InitBase64Table(const std::string& alphabet)
base64_table[int('=')] = 0; base64_table[int('=')] = 0;
return base64_table; return base64_table;
} }
Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) {
{ if ( arg_alphabet.size() > 0 ) {
if ( arg_alphabet.size() > 0 )
{
assert(arg_alphabet.size() == 64); assert(arg_alphabet.size() == 64);
alphabet = arg_alphabet; alphabet = arg_alphabet;
} }
else else {
{
alphabet = default_alphabet; alphabet = default_alphabet;
} }
@ -108,16 +97,14 @@ Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_al
base64_padding = base64_after_padding = 0; base64_padding = base64_after_padding = 0;
errored = 0; errored = 0;
conn = arg_conn; conn = arg_conn;
} }
Base64Converter::~Base64Converter() Base64Converter::~Base64Converter() {
{
if ( base64_table != default_base64_table ) if ( base64_table != default_base64_table )
delete[] base64_table; delete[] base64_table;
} }
int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) {
{
int blen; int blen;
char* buf; char* buf;
@ -128,13 +115,11 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
if ( ! pbuf ) if ( ! pbuf )
reporter->InternalError("nil pointer to decoding result buffer"); reporter->InternalError("nil pointer to decoding result buffer");
if ( *pbuf ) if ( *pbuf ) {
{
buf = *pbuf; buf = *pbuf;
blen = *pblen; blen = *pblen;
} }
else else {
{
// Estimate the maximal number of 3-byte groups needed, // Estimate the maximal number of 3-byte groups needed,
// plus 1 byte for the optional ending NUL. // plus 1 byte for the optional ending NUL.
blen = int((len + base64_group_next + 3) / 4) * 3 + 1; blen = int((len + base64_group_next + 3) / 4) * 3 + 1;
@ -143,14 +128,11 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
int dlen = 0; int dlen = 0;
while ( true ) while ( true ) {
{ if ( base64_group_next == 4 ) {
if ( base64_group_next == 4 )
{
// For every group of 4 6-bit numbers, // For every group of 4 6-bit numbers,
// write the decoded 3 bytes to the buffer. // write the decoded 3 bytes to the buffer.
if ( base64_after_padding ) if ( base64_after_padding ) {
{
if ( ++errored == 1 ) if ( ++errored == 1 )
IllegalEncoding("extra base64 groups after '=' padding are ignored"); IllegalEncoding("extra base64 groups after '=' padding are ignored");
base64_group_next = 0; base64_group_next = 0;
@ -172,7 +154,7 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
*buf++ = char((bit32 >> 8) & 0xff); *buf++ = char((bit32 >> 8) & 0xff);
if ( --num_octets >= 0 ) if ( --num_octets >= 0 )
*buf++ = char((bit32)&0xff); *buf++ = char((bit32) & 0xff);
if ( base64_padding > 0 ) if ( base64_padding > 0 )
base64_after_padding = 1; base64_after_padding = 1;
@ -191,8 +173,7 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
int k = base64_table[c]; int k = base64_table[c];
if ( k >= 0 ) if ( k >= 0 )
base64_group[base64_group_next++] = k; base64_group[base64_group_next++] = k;
else else {
{
if ( ++errored == 1 ) if ( ++errored == 1 )
IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c)); IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c));
} }
@ -202,17 +183,15 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
*pblen = buf - *pbuf; *pblen = buf - *pbuf;
return dlen; return dlen;
} }
int Base64Converter::Done(int* pblen, char** pbuf) int Base64Converter::Done(int* pblen, char** pbuf) {
{
const char* padding = "==="; const char* padding = "===";
if ( base64_group_next != 0 ) if ( base64_group_next != 0 ) {
{
if ( base64_group_next < 4 ) if ( base64_group_next < 4 )
IllegalEncoding(util::fmt("incomplete base64 group, padding with %d bits of 0", IllegalEncoding(
(4 - base64_group_next) * 6)); util::fmt("incomplete base64 group, padding with %d bits of 0", (4 - base64_group_next) * 6));
Decode(4 - base64_group_next, padding, pblen, pbuf); Decode(4 - base64_group_next, padding, pblen, pbuf);
return -1; return -1;
} }
@ -221,21 +200,18 @@ int Base64Converter::Done(int* pblen, char** pbuf)
*pblen = 0; *pblen = 0;
return 0; return 0;
} }
void Base64Converter::IllegalEncoding(const char* msg) void Base64Converter::IllegalEncoding(const char* msg) {
{
// strncpy(error_msg, msg, sizeof(error_msg)); // strncpy(error_msg, msg, sizeof(error_msg));
if ( conn ) if ( conn )
conn->Weird("base64_illegal_encoding", msg); conn->Weird("base64_illegal_encoding", msg);
else else
reporter->Error("%s", msg); reporter->Error("%s", msg);
} }
String* decode_base64(const String* s, const String* a, Connection* conn) String* decode_base64(const String* s, const String* a, Connection* conn) {
{ if ( a && a->Len() != 0 && a->Len() != 64 ) {
if ( a && a->Len() != 0 && a->Len() != 64 )
{
reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString()); reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString());
return nullptr; return nullptr;
} }
@ -263,12 +239,10 @@ String* decode_base64(const String* s, const String* a, Connection* conn)
err: err:
delete[] rbuf; delete[] rbuf;
return nullptr; return nullptr;
} }
String* encode_base64(const String* s, const String* a, Connection* conn) String* encode_base64(const String* s, const String* a, Connection* conn) {
{ if ( a && a->Len() != 0 && a->Len() != 64 ) {
if ( a && a->Len() != 0 && a->Len() != 64 )
{
reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString()); reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString());
return nullptr; return nullptr;
} }
@ -279,6 +253,6 @@ String* encode_base64(const String* s, const String* a, Connection* conn)
enc.Encode(s->Len(), (const unsigned char*)s->Bytes(), &outlen, &outbuf); enc.Encode(s->Len(), (const unsigned char*)s->Bytes(), &outlen, &outbuf);
return new String(true, (u_char*)outbuf, outlen); return new String(true, (u_char*)outbuf, outlen);
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,18 +4,15 @@
#include <string> #include <string>
namespace zeek namespace zeek {
{
class String; class String;
class Connection; class Connection;
namespace detail namespace detail {
{
// Maybe we should have a base class for generic decoders? // Maybe we should have a base class for generic decoders?
class Base64Converter class Base64Converter {
{
public: public:
// <conn> is used for error reporting. If it is set to zero (as, // <conn> is used for error reporting. If it is set to zero (as,
// e.g., done by the built-in functions decode_base64() and // e.g., done by the built-in functions decode_base64() and
@ -63,10 +60,10 @@ protected:
int* base64_table; int* base64_table;
int errored; // if true, we encountered an error - skip further processing int errored; // if true, we encountered an error - skip further processing
Connection* conn; Connection* conn;
}; };
String* decode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr); String* decode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr);
String* encode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr); String* encode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr);
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -4,9 +4,8 @@
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
BifReturnVal::BifReturnVal(std::nullptr_t) noexcept { } BifReturnVal::BifReturnVal(std::nullptr_t) noexcept {}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,31 +6,27 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class Val; class Val;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
namespace detail namespace detail {
{
/** /**
* A simple wrapper class to use for the return value of BIFs so that * A simple wrapper class to use for the return value of BIFs so that
* they may return either a Val* or IntrusivePtr<Val> (the former could * they may return either a Val* or IntrusivePtr<Val> (the former could
* potentially be deprecated). * potentially be deprecated).
*/ */
class BifReturnVal class BifReturnVal {
{
public: public:
template <typename T> BifReturnVal(IntrusivePtr<T> v) noexcept : rval(AdoptRef{}, v.release()) template<typename T>
{ BifReturnVal(IntrusivePtr<T> v) noexcept : rval(AdoptRef{}, v.release()) {}
}
BifReturnVal(std::nullptr_t) noexcept; BifReturnVal(std::nullptr_t) noexcept;
ValPtr rval; ValPtr rval;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -9,30 +9,23 @@
#include "zeek/DFA.h" #include "zeek/DFA.h"
#include "zeek/RE.h" #include "zeek/RE.h"
namespace zeek::detail namespace zeek::detail {
{
CCL::CCL() CCL::CCL() {
{
syms = new int_list; syms = new int_list;
index = -(rem->InsertCCL(this) + 1); index = -(rem->InsertCCL(this) + 1);
negated = 0; negated = 0;
} }
CCL::~CCL() CCL::~CCL() { delete syms; }
{
delete syms;
}
void CCL::Negate() void CCL::Negate() {
{
negated = 1; negated = 1;
Add(SYM_BOL); Add(SYM_BOL);
Add(SYM_EOL); Add(SYM_EOL);
} }
void CCL::Add(int sym) void CCL::Add(int sym) {
{
auto sym_p = static_cast<std::intptr_t>(sym); auto sym_p = static_cast<std::intptr_t>(sym);
// Check to see if the character is already in the ccl. // Check to see if the character is already in the ccl.
@ -41,11 +34,8 @@ void CCL::Add(int sym)
return; return;
syms->push_back(sym_p); syms->push_back(sym_p);
} }
void CCL::Sort() void CCL::Sort() { std::sort(syms->begin(), syms->end()); }
{
std::sort(syms->begin(), syms->end());
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -5,13 +5,11 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
namespace zeek::detail namespace zeek::detail {
{
using int_list = std::vector<std::intptr_t>; using int_list = std::vector<std::intptr_t>;
class CCL class CCL {
{
public: public:
CCL(); CCL();
~CCL(); ~CCL();
@ -25,8 +23,7 @@ public:
int_list* Syms() { return syms; } int_list* Syms() { return syms; }
void ReplaceSyms(int_list* new_syms) void ReplaceSyms(int_list* new_syms) {
{
delete syms; delete syms;
syms = new_syms; syms = new_syms;
} }
@ -35,6 +32,6 @@ protected:
int_list* syms; int_list* syms;
int negated; int negated;
int index; int index;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -15,36 +15,31 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
// A comparison callable to assist with consistent iteration order over tables // A comparison callable to assist with consistent iteration order over tables
// during reservation & writes. // during reservation & writes.
struct HashKeyComparer struct HashKeyComparer {
{ bool operator()(const std::unique_ptr<HashKey>& a, const std::unique_ptr<HashKey>& b) const {
bool operator()(const std::unique_ptr<HashKey>& a, const std::unique_ptr<HashKey>& b) const
{
if ( a->Hash() != b->Hash() ) if ( a->Hash() != b->Hash() )
return a->Hash() < b->Hash(); return a->Hash() < b->Hash();
if ( a->Size() != b->Size() ) if ( a->Size() != b->Size() )
return a->Size() < b->Size(); return a->Size() < b->Size();
return memcmp(a->Key(), b->Key(), a->Size()) < 0; return memcmp(a->Key(), b->Key(), a->Size()) < 0;
} }
}; };
using HashkeyMap = std::map<std::unique_ptr<HashKey>, ListValPtr, HashKeyComparer>; using HashkeyMap = std::map<std::unique_ptr<HashKey>, ListValPtr, HashKeyComparer>;
using HashkeyMapPtr = std::unique_ptr<HashkeyMap>; using HashkeyMapPtr = std::unique_ptr<HashkeyMap>;
// Helper that produces a table from HashKeys to the ListVal indexes into the // Helper that produces a table from HashKeys to the ListVal indexes into the
// table, that we can iterate over in sorted-Hashkey order. // table, that we can iterate over in sorted-Hashkey order.
const HashkeyMapPtr ordered_hashkeys(const TableVal* tv) const HashkeyMapPtr ordered_hashkeys(const TableVal* tv) {
{
auto res = std::make_unique<HashkeyMap>(); auto res = std::make_unique<HashkeyMap>();
auto tbl = tv->AsTable(); auto tbl = tv->AsTable();
auto idx = 0; auto idx = 0;
for ( const auto& entry : *tbl ) for ( const auto& entry : *tbl ) {
{
auto k = entry.GetHashKey(); auto k = entry.GetHashKey();
// Potential optimization: we could do without the following if // Potential optimization: we could do without the following if
// the caller uses k directly to determine key length & // the caller uses k directly to determine key length &
@ -58,27 +53,23 @@ const HashkeyMapPtr ordered_hashkeys(const TableVal* tv)
} }
return res; return res;
} }
CompositeHash::CompositeHash(TypeListPtr composite_type) : type(std::move(composite_type)) CompositeHash::CompositeHash(TypeListPtr composite_type) : type(std::move(composite_type)) {
{
if ( type->GetTypes().size() == 1 ) if ( type->GetTypes().size() == 1 )
is_singleton = true; is_singleton = true;
} }
std::unique_ptr<HashKey> CompositeHash::MakeHashKey(const Val& argv, bool type_check) const std::unique_ptr<HashKey> CompositeHash::MakeHashKey(const Val& argv, bool type_check) const {
{
auto res = std::make_unique<HashKey>(); auto res = std::make_unique<HashKey>();
const auto& tl = type->GetTypes(); const auto& tl = type->GetTypes();
if ( is_singleton ) if ( is_singleton ) {
{
const Val* v = &argv; const Val* v = &argv;
// This is the "singleton" case -- actually just a single value // This is the "singleton" case -- actually just a single value
// that may come bundled in a list. If so, unwrap it. // that may come bundled in a list. If so, unwrap it.
if ( v->GetType()->Tag() == TYPE_LIST ) if ( v->GetType()->Tag() == TYPE_LIST ) {
{
auto lv = v->AsListVal(); auto lv = v->AsListVal();
if ( type_check && lv->Length() != 1 ) if ( type_check && lv->Length() != 1 )
@ -105,25 +96,21 @@ std::unique_ptr<HashKey> CompositeHash::MakeHashKey(const Val& argv, bool type_c
// The size computation resulted in a requested buffer size; allocate it. // The size computation resulted in a requested buffer size; allocate it.
res->Allocate(); res->Allocate();
for ( auto i = 0u; i < tl.size(); ++i ) for ( auto i = 0u; i < tl.size(); ++i ) {
{ if ( ! SingleValHash(*res, argv.AsListVal()->Idx(i).get(), tl[i].get(), type_check, false, false) )
if ( ! SingleValHash(*res, argv.AsListVal()->Idx(i).get(), tl[i].get(), type_check, false,
false) )
return nullptr; return nullptr;
} }
return res; return res;
} }
ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const {
{
auto l = make_intrusive<ListVal>(TYPE_ANY); auto l = make_intrusive<ListVal>(TYPE_ANY);
const auto& tl = type->GetTypes(); const auto& tl = type->GetTypes();
hk.ResetRead(); hk.ResetRead();
for ( const auto& type : tl ) for ( const auto& type : tl ) {
{
ValPtr v; ValPtr v;
if ( ! RecoverOneVal(hk, type.get(), &v, false, is_singleton) ) if ( ! RecoverOneVal(hk, type.get(), &v, false, is_singleton) )
@ -134,30 +121,24 @@ ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const
} }
return l; return l;
} }
bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool optional, bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool optional, bool singleton) const {
bool singleton) const
{
TypeTag tag = t->Tag(); TypeTag tag = t->Tag();
InternalTypeTag it = t->InternalType(); InternalTypeTag it = t->InternalType();
if ( optional ) if ( optional ) {
{
bool opt; bool opt;
hk.Read("optional", opt); hk.Read("optional", opt);
if ( ! opt ) if ( ! opt ) {
{
*pval = nullptr; *pval = nullptr;
return true; return true;
} }
} }
switch ( it ) switch ( it ) {
{ case TYPE_INTERNAL_INT: {
case TYPE_INTERNAL_INT:
{
zeek_int_t i; zeek_int_t i;
hk.Read("int", i); hk.Read("int", i);
@ -167,42 +148,30 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
*pval = val_mgr->Bool(i); *pval = val_mgr->Bool(i);
else if ( tag == TYPE_INT ) else if ( tag == TYPE_INT )
*pval = val_mgr->Int(i); *pval = val_mgr->Int(i);
else else {
{ reporter->InternalError("bad internal unsigned int in CompositeHash::RecoverOneVal()");
reporter->InternalError(
"bad internal unsigned int in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_INTERNAL_UNSIGNED: case TYPE_INTERNAL_UNSIGNED: {
{
zeek_uint_t u; zeek_uint_t u;
hk.Read("unsigned", u); hk.Read("unsigned", u);
switch ( tag ) switch ( tag ) {
{ case TYPE_COUNT: *pval = val_mgr->Count(u); break;
case TYPE_COUNT:
*pval = val_mgr->Count(u);
break;
case TYPE_PORT: case TYPE_PORT: *pval = val_mgr->Port(u); break;
*pval = val_mgr->Port(u);
break;
default: default:
reporter->InternalError( reporter->InternalError("bad internal unsigned int in CompositeHash::RecoverOneVal()");
"bad internal unsigned int in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_INTERNAL_DOUBLE: case TYPE_INTERNAL_DOUBLE: {
{
double d; double d;
hk.Read("double", d); hk.Read("double", d);
@ -212,33 +181,25 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
*pval = make_intrusive<TimeVal>(d); *pval = make_intrusive<TimeVal>(d);
else else
*pval = make_intrusive<DoubleVal>(d); *pval = make_intrusive<DoubleVal>(d);
} } break;
break;
case TYPE_INTERNAL_ADDR: case TYPE_INTERNAL_ADDR: {
{
hk.AlignRead(sizeof(uint32_t)); hk.AlignRead(sizeof(uint32_t));
hk.EnsureReadSpace(sizeof(uint32_t) * 4); hk.EnsureReadSpace(sizeof(uint32_t) * 4);
IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network); IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network);
hk.SkipRead("addr", sizeof(uint32_t) * 4); hk.SkipRead("addr", sizeof(uint32_t) * 4);
switch ( tag ) switch ( tag ) {
{ case TYPE_ADDR: *pval = make_intrusive<AddrVal>(addr); break;
case TYPE_ADDR:
*pval = make_intrusive<AddrVal>(addr);
break;
default: default:
reporter->InternalError( reporter->InternalError("bad internal address in CompositeHash::RecoverOneVal()");
"bad internal address in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_INTERNAL_SUBNET: case TYPE_INTERNAL_SUBNET: {
{
hk.AlignRead(sizeof(uint32_t)); hk.AlignRead(sizeof(uint32_t));
hk.EnsureReadSpace(sizeof(uint32_t) * 4); hk.EnsureReadSpace(sizeof(uint32_t) * 4);
IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network); IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network);
@ -247,16 +208,12 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
uint32_t width; uint32_t width;
hk.Read("subnet-width", width); hk.Read("subnet-width", width);
*pval = make_intrusive<SubNetVal>(addr, width); *pval = make_intrusive<SubNetVal>(addr, width);
} } break;
break;
case TYPE_INTERNAL_VOID: case TYPE_INTERNAL_VOID:
case TYPE_INTERNAL_OTHER: case TYPE_INTERNAL_OTHER: {
{ switch ( t->Tag() ) {
switch ( t->Tag() ) case TYPE_FUNC: {
{
case TYPE_FUNC:
{
uint32_t id; uint32_t id;
hk.Read("func", id); hk.Read("func", id);
@ -273,36 +230,29 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
const auto& pvt = (*pval)->GetType(); const auto& pvt = (*pval)->GetType();
if ( ! pvt ) if ( ! pvt )
reporter->InternalError( reporter->InternalError("bad aggregate Val in CompositeHash::RecoverOneVal()");
"bad aggregate Val in CompositeHash::RecoverOneVal()");
else if ( t->Tag() != TYPE_FUNC && ! same_type(pvt, t) ) else if ( t->Tag() != TYPE_FUNC && ! same_type(pvt, t) )
// ### Maybe fix later, but may be fundamentally un-checkable --US // ### Maybe fix later, but may be fundamentally un-checkable --US
{ {
reporter->InternalError( reporter->InternalError("inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
"inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
// ### A crude approximation for now. // ### A crude approximation for now.
else if ( t->Tag() == TYPE_FUNC && pvt->Tag() != TYPE_FUNC ) else if ( t->Tag() == TYPE_FUNC && pvt->Tag() != TYPE_FUNC ) {
{ reporter->InternalError("inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
reporter->InternalError(
"inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_PATTERN: case TYPE_PATTERN: {
{
const char* texts[2] = {nullptr, nullptr}; const char* texts[2] = {nullptr, nullptr};
uint64_t lens[2] = {0, 0}; uint64_t lens[2] = {0, 0};
if ( ! singleton ) if ( ! singleton ) {
{
hk.Read("pattern-len1", lens[0]); hk.Read("pattern-len1", lens[0]);
hk.Read("pattern-len2", lens[1]); hk.Read("pattern-len2", lens[1]);
} }
@ -315,29 +265,23 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
RE_Matcher* re = new RE_Matcher(texts[0], texts[1]); RE_Matcher* re = new RE_Matcher(texts[0], texts[1]);
if ( ! re->Compile() ) if ( ! re->Compile() )
reporter->InternalError("failed compiling table/set key pattern: %s", reporter->InternalError("failed compiling table/set key pattern: %s", re->PatternText());
re->PatternText());
*pval = make_intrusive<PatternVal>(re); *pval = make_intrusive<PatternVal>(re);
} } break;
break;
case TYPE_RECORD: case TYPE_RECORD: {
{
auto rt = t->AsRecordType(); auto rt = t->AsRecordType();
int num_fields = rt->NumFields(); int num_fields = rt->NumFields();
std::vector<ValPtr> values; std::vector<ValPtr> values;
int i; int i;
for ( i = 0; i < num_fields; ++i ) for ( i = 0; i < num_fields; ++i ) {
{
ValPtr v; ValPtr v;
Attributes* a = rt->FieldDecl(i)->attrs.get(); Attributes* a = rt->FieldDecl(i)->attrs.get();
bool is_optional = (a && a->Find(ATTR_OPTIONAL)); bool is_optional = (a && a->Find(ATTR_OPTIONAL));
if ( ! RecoverOneVal(hk, rt->GetFieldType(i).get(), &v, is_optional, if ( ! RecoverOneVal(hk, rt->GetFieldType(i).get(), &v, is_optional, false) ) {
false) )
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -346,10 +290,8 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
// abort() and broken the call tree that clang-tidy is relying on to // abort() and broken the call tree that clang-tidy is relying on to
// get the error described. // get the error described.
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Branch) // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Branch)
if ( ! (v || is_optional) ) if ( ! (v || is_optional) ) {
{ reporter->InternalError("didn't recover expected number of fields from HashKey");
reporter->InternalError(
"didn't recover expected number of fields from HashKey");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -359,39 +301,32 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
ASSERT(int(values.size()) == num_fields); ASSERT(int(values.size()) == num_fields);
auto rv = make_intrusive<RecordVal>(IntrusivePtr{NewRef{}, rt}, auto rv = make_intrusive<RecordVal>(IntrusivePtr{NewRef{}, rt}, false /* init_fields */);
false /* init_fields */);
for ( int i = 0; i < num_fields; ++i ) for ( int i = 0; i < num_fields; ++i )
rv->AppendField(std::move(values[i]), rt->GetFieldType(i)); rv->AppendField(std::move(values[i]), rt->GetFieldType(i));
*pval = std::move(rv); *pval = std::move(rv);
} } break;
break;
case TYPE_TABLE: case TYPE_TABLE: {
{
int n; int n;
hk.Read("table-size", n); hk.Read("table-size", n);
auto tt = t->AsTableType(); auto tt = t->AsTableType();
auto tv = make_intrusive<TableVal>(IntrusivePtr{NewRef{}, tt}); auto tv = make_intrusive<TableVal>(IntrusivePtr{NewRef{}, tt});
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
ValPtr key; ValPtr key;
if ( ! RecoverOneVal(hk, tt->GetIndices().get(), &key, false, false) ) if ( ! RecoverOneVal(hk, tt->GetIndices().get(), &key, false, false) ) {
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
if ( t->IsSet() ) if ( t->IsSet() )
tv->Assign(std::move(key), nullptr); tv->Assign(std::move(key), nullptr);
else else {
{
ValPtr value; ValPtr value;
if ( ! RecoverOneVal(hk, tt->Yield().get(), &value, false, false) ) if ( ! RecoverOneVal(hk, tt->Yield().get(), &value, false, false) ) {
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -400,27 +335,22 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
} }
*pval = std::move(tv); *pval = std::move(tv);
} } break;
break;
case TYPE_VECTOR: case TYPE_VECTOR: {
{
unsigned int n; unsigned int n;
hk.Read("vector-size", n); hk.Read("vector-size", n);
auto vt = t->AsVectorType(); auto vt = t->AsVectorType();
auto vv = make_intrusive<VectorVal>(IntrusivePtr{NewRef{}, vt}); auto vv = make_intrusive<VectorVal>(IntrusivePtr{NewRef{}, vt});
for ( unsigned int i = 0; i < n; ++i ) for ( unsigned int i = 0; i < n; ++i ) {
{
unsigned int index; unsigned int index;
hk.Read("vector-idx", index); hk.Read("vector-idx", index);
bool have_val; bool have_val;
hk.Read("vector-idx-present", have_val); hk.Read("vector-idx-present", have_val);
ValPtr value; ValPtr value;
if ( have_val && if ( have_val && ! RecoverOneVal(hk, vt->Yield().get(), &value, false, false) ) {
! RecoverOneVal(hk, vt->Yield().get(), &value, false, false) )
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -429,18 +359,15 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
} }
*pval = std::move(vv); *pval = std::move(vv);
} } break;
break;
case TYPE_LIST: case TYPE_LIST: {
{
int n; int n;
hk.Read("list-size", n); hk.Read("list-size", n);
auto tl = t->AsTypeList(); auto tl = t->AsTypeList();
auto lv = make_intrusive<ListVal>(TYPE_ANY); auto lv = make_intrusive<ListVal>(TYPE_ANY);
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
ValPtr v; ValPtr v;
Type* it = tl->GetTypes()[i].get(); Type* it = tl->GetTypes()[i].get();
if ( ! RecoverOneVal(hk, it, &v, false, false) ) if ( ! RecoverOneVal(hk, it, &v, false, false) )
@ -449,55 +376,45 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
} }
*pval = std::move(lv); *pval = std::move(lv);
} } break;
break;
default: default: {
{
reporter->InternalError("bad index type in CompositeHash::RecoverOneVal"); reporter->InternalError("bad index type in CompositeHash::RecoverOneVal");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} }
} } break;
break;
case TYPE_INTERNAL_STRING: case TYPE_INTERNAL_STRING: {
{
int n = hk.Size(); int n = hk.Size();
if ( ! singleton ) if ( ! singleton ) {
{
hk.Read("string-len", n); hk.Read("string-len", n);
hk.EnsureReadSpace(n); hk.EnsureReadSpace(n);
} }
*pval = make_intrusive<StringVal>(new String((const byte_vec)hk.KeyAtRead(), n, true)); *pval = make_intrusive<StringVal>(new String((const byte_vec)hk.KeyAtRead(), n, true));
hk.SkipRead("string", n); hk.SkipRead("string", n);
} } break;
break;
case TYPE_INTERNAL_ERROR: case TYPE_INTERNAL_ERROR: break;
break;
} }
return true; return true;
} }
bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional,
bool optional, bool singleton) const bool singleton) const {
{
InternalTypeTag t = bt->InternalType(); InternalTypeTag t = bt->InternalType();
if ( type_check && v ) if ( type_check && v ) {
{
InternalTypeTag vt = v->GetType()->InternalType(); InternalTypeTag vt = v->GetType()->InternalType();
if ( vt != t ) if ( vt != t )
return false; return false;
} }
if ( optional ) if ( optional ) {
{
// Add a marker saying whether the optional field is set. // Add a marker saying whether the optional field is set.
hk.Write("optional", v != nullptr); hk.Write("optional", v != nullptr);
@ -510,15 +427,10 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
if ( ! v ) if ( ! v )
return false; return false;
switch ( t ) switch ( t ) {
{ case TYPE_INTERNAL_INT: hk.Write("int", v->AsInt()); break;
case TYPE_INTERNAL_INT:
hk.Write("int", v->AsInt());
break;
case TYPE_INTERNAL_UNSIGNED: case TYPE_INTERNAL_UNSIGNED: hk.Write("unsigned", v->AsCount()); break;
hk.Write("unsigned", v->AsCount());
break;
case TYPE_INTERNAL_ADDR: case TYPE_INTERNAL_ADDR:
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
@ -541,17 +453,12 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("subnet-width", v->AsSubNet().Length()); hk.Write("subnet-width", v->AsSubNet().Length());
break; break;
case TYPE_INTERNAL_DOUBLE: case TYPE_INTERNAL_DOUBLE: hk.Write("double", v->InternalDouble()); break;
hk.Write("double", v->InternalDouble());
break;
case TYPE_INTERNAL_VOID: case TYPE_INTERNAL_VOID:
case TYPE_INTERNAL_OTHER: case TYPE_INTERNAL_OTHER: {
{ switch ( v->GetType()->Tag() ) {
switch ( v->GetType()->Tag() ) case TYPE_FUNC: {
{
case TYPE_FUNC:
{
auto f = v->AsFunc(); auto f = v->AsFunc();
if ( ! func_to_func_id ) if ( ! func_to_func_id )
@ -560,8 +467,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
auto id_mapping = func_to_func_id->find(f); auto id_mapping = func_to_func_id->find(f);
uint32_t id; uint32_t id;
if ( id_mapping == func_to_func_id->end() ) if ( id_mapping == func_to_func_id->end() ) {
{
// We need the pointer to stick around // We need the pointer to stick around
// for our lifetime, so we have to get // for our lifetime, so we have to get
// a non-const version we can ref. // a non-const version we can ref.
@ -575,22 +481,17 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
id = id_mapping->second; id = id_mapping->second;
hk.Write("func", id); hk.Write("func", id);
} } break;
break;
case TYPE_PATTERN: case TYPE_PATTERN: {
{ const char* texts[2] = {v->AsPattern()->PatternText(), v->AsPattern()->AnywherePatternText()};
const char* texts[2] = {v->AsPattern()->PatternText(),
v->AsPattern()->AnywherePatternText()};
uint64_t lens[2] = {strlen(texts[0]) + 1, strlen(texts[1]) + 1}; uint64_t lens[2] = {strlen(texts[0]) + 1, strlen(texts[1]) + 1};
if ( ! singleton ) if ( ! singleton ) {
{
hk.Write("pattern-len1", lens[0]); hk.Write("pattern-len1", lens[0]);
hk.Write("pattern-len2", lens[1]); hk.Write("pattern-len2", lens[1]);
} }
else else {
{
hk.Reserve("pattern", lens[0] + lens[1]); hk.Reserve("pattern", lens[0] + lens[1]);
hk.Allocate(); hk.Allocate();
} }
@ -600,8 +501,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
break; break;
} }
case TYPE_RECORD: case TYPE_RECORD: {
{
auto rv = v->AsRecordVal(); auto rv = v->AsRecordVal();
auto rt = bt->AsRecordType(); auto rt = bt->AsRecordType();
int num_fields = rt->NumFields(); int num_fields = rt->NumFields();
@ -609,8 +509,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
for ( int i = 0; i < num_fields; ++i ) for ( int i = 0; i < num_fields; ++i ) {
{
auto rv_i = rv->GetField(i); auto rv_i = rv->GetField(i);
Attributes* a = rt->FieldDecl(i)->attrs.get(); Attributes* a = rt->FieldDecl(i)->attrs.get();
@ -619,15 +518,14 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
if ( ! (rv_i || optional_attr) ) if ( ! (rv_i || optional_attr) )
return false; return false;
if ( ! SingleValHash(hk, rv_i.get(), rt->GetFieldType(i).get(), type_check, if ( ! SingleValHash(hk, rv_i.get(), rt->GetFieldType(i).get(), type_check, optional_attr,
optional_attr, false) ) false) )
return false; return false;
} }
break; break;
} }
case TYPE_TABLE: case TYPE_TABLE: {
{
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
@ -636,28 +534,22 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("table-size", tv->Size()); hk.Write("table-size", tv->Size());
for ( auto& kv : *hashkeys ) for ( auto& kv : *hashkeys ) {
{
auto key = kv.second; auto key = kv.second;
if ( ! SingleValHash(hk, key.get(), key->GetType().get(), type_check, false, if ( ! SingleValHash(hk, key.get(), key->GetType().get(), type_check, false, false) )
false) )
return false; return false;
if ( ! v->GetType()->IsSet() ) if ( ! v->GetType()->IsSet() ) {
{
auto val = const_cast<TableVal*>(tv)->FindOrDefault(key); auto val = const_cast<TableVal*>(tv)->FindOrDefault(key);
if ( ! SingleValHash(hk, val.get(), val->GetType().get(), type_check, if ( ! SingleValHash(hk, val.get(), val->GetType().get(), type_check, false, false) )
false, false) )
return false; return false;
} }
} }
} } break;
break;
case TYPE_VECTOR: case TYPE_VECTOR: {
{
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
@ -666,25 +558,19 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("vector-size", vv->Size()); hk.Write("vector-size", vv->Size());
for ( unsigned int i = 0; i < vv->Size(); ++i ) for ( unsigned int i = 0; i < vv->Size(); ++i ) {
{
auto val = vv->ValAt(i); auto val = vv->ValAt(i);
hk.Write("vector-idx", i); hk.Write("vector-idx", i);
hk.Write("vector-idx-present", val != nullptr); hk.Write("vector-idx-present", val != nullptr);
if ( val && ! SingleValHash(hk, val.get(), vt->Yield().get(), type_check, if ( val && ! SingleValHash(hk, val.get(), vt->Yield().get(), type_check, false, false) )
false, false) )
return false; return false;
} }
} } break;
break;
case TYPE_LIST: case TYPE_LIST: {
{ if ( ! hk.IsAllocated() ) {
if ( ! hk.IsAllocated() ) if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false, false) )
{
if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false,
false) )
return false; return false;
hk.Allocate(); hk.Allocate();
@ -694,18 +580,14 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("list-size", lv->Length()); hk.Write("list-size", lv->Length());
for ( int i = 0; i < lv->Length(); ++i ) for ( int i = 0; i < lv->Length(); ++i ) {
{
Val* entry_val = lv->Idx(i).get(); Val* entry_val = lv->Idx(i).get();
if ( ! SingleValHash(hk, entry_val, entry_val->GetType().get(), type_check, if ( ! SingleValHash(hk, entry_val, entry_val->GetType().get(), type_check, false, false) )
false, false) )
return false; return false;
} }
} } break;
break;
default: default: {
{
reporter->InternalError("bad index type in CompositeHash::SingleValHash"); reporter->InternalError("bad index type in CompositeHash::SingleValHash");
return false; return false;
} }
@ -714,8 +596,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
break; // case TYPE_INTERNAL_VOID/OTHER break; // case TYPE_INTERNAL_VOID/OTHER
} }
case TYPE_INTERNAL_STRING: case TYPE_INTERNAL_STRING: {
{
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
@ -725,18 +606,15 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("string-len", sval->Len()); hk.Write("string-len", sval->Len());
hk.Write("string", sval->Bytes(), sval->Len()); hk.Write("string", sval->Bytes(), sval->Len());
} } break;
break;
default: default: return false;
return false;
} }
return true; return true;
} }
bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const {
{
if ( hk.IsAllocated() ) if ( hk.IsAllocated() )
return true; return true;
@ -745,96 +623,71 @@ bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool
hk.Allocate(); hk.Allocate();
return true; return true;
} }
bool CompositeHash::ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool CompositeHash::ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool calc_static_size) const {
bool calc_static_size) const
{
const auto& tl = type->GetTypes(); const auto& tl = type->GetTypes();
for ( auto i = 0u; i < tl.size(); ++i ) for ( auto i = 0u; i < tl.size(); ++i ) {
{ if ( ! ReserveSingleTypeKeySize(hk, tl[i].get(), v ? v->AsListVal()->Idx(i).get() : nullptr, type_check, false,
if ( ! ReserveSingleTypeKeySize(hk, tl[i].get(), v ? v->AsListVal()->Idx(i).get() : nullptr, calc_static_size, is_singleton) )
type_check, false, calc_static_size, is_singleton) )
return false; return false;
} }
return true; return true;
} }
bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v, bool type_check, bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v, bool type_check, bool optional,
bool optional, bool calc_static_size, bool calc_static_size, bool singleton) const {
bool singleton) const
{
InternalTypeTag t = bt->InternalType(); InternalTypeTag t = bt->InternalType();
if ( optional ) if ( optional ) {
{
hk.ReserveType<bool>("optional"); hk.ReserveType<bool>("optional");
if ( ! v ) if ( ! v )
return true; return true;
} }
if ( type_check && v ) if ( type_check && v ) {
{
InternalTypeTag vt = v->GetType()->InternalType(); InternalTypeTag vt = v->GetType()->InternalType();
if ( vt != t ) if ( vt != t )
return false; return false;
} }
switch ( t ) switch ( t ) {
{ case TYPE_INTERNAL_INT: hk.ReserveType<zeek_int_t>("int"); break;
case TYPE_INTERNAL_INT:
hk.ReserveType<zeek_int_t>("int");
break;
case TYPE_INTERNAL_UNSIGNED: case TYPE_INTERNAL_UNSIGNED: hk.ReserveType<zeek_int_t>("unsigned"); break;
hk.ReserveType<zeek_int_t>("unsigned");
break;
case TYPE_INTERNAL_ADDR: case TYPE_INTERNAL_ADDR: hk.Reserve("addr", sizeof(uint32_t) * 4, sizeof(uint32_t)); break;
hk.Reserve("addr", sizeof(uint32_t) * 4, sizeof(uint32_t));
break;
case TYPE_INTERNAL_SUBNET: case TYPE_INTERNAL_SUBNET: hk.Reserve("subnet", sizeof(uint32_t) * 5, sizeof(uint32_t)); break;
hk.Reserve("subnet", sizeof(uint32_t) * 5, sizeof(uint32_t));
break;
case TYPE_INTERNAL_DOUBLE: case TYPE_INTERNAL_DOUBLE: hk.ReserveType<double>("double"); break;
hk.ReserveType<double>("double");
break;
case TYPE_INTERNAL_VOID: case TYPE_INTERNAL_VOID:
case TYPE_INTERNAL_OTHER: case TYPE_INTERNAL_OTHER: {
{ switch ( bt->Tag() ) {
switch ( bt->Tag() ) case TYPE_FUNC: {
{
case TYPE_FUNC:
{
hk.ReserveType<uint32_t>("func"); hk.ReserveType<uint32_t>("func");
break; break;
} }
case TYPE_PATTERN: case TYPE_PATTERN: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
if ( ! singleton ) if ( ! singleton ) {
{
hk.ReserveType<uint64_t>("pattern-len1"); hk.ReserveType<uint64_t>("pattern-len1");
hk.ReserveType<uint64_t>("pattern-len2"); hk.ReserveType<uint64_t>("pattern-len2");
} }
// +1 in the following to include null terminators // +1 in the following to include null terminators
hk.Reserve("pattern-string1", strlen(v->AsPattern()->PatternText()) + 1, 0); hk.Reserve("pattern-string1", strlen(v->AsPattern()->PatternText()) + 1, 0);
hk.Reserve("pattern-string1", strlen(v->AsPattern()->AnywherePatternText()) + 1, hk.Reserve("pattern-string1", strlen(v->AsPattern()->AnywherePatternText()) + 1, 0);
0);
break; break;
} }
case TYPE_RECORD: case TYPE_RECORD: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
@ -842,22 +695,19 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
RecordType* rt = bt->AsRecordType(); RecordType* rt = bt->AsRecordType();
int num_fields = rt->NumFields(); int num_fields = rt->NumFields();
for ( int i = 0; i < num_fields; ++i ) for ( int i = 0; i < num_fields; ++i ) {
{
Attributes* a = rt->FieldDecl(i)->attrs.get(); Attributes* a = rt->FieldDecl(i)->attrs.get();
bool optional_attr = (a && a->Find(ATTR_OPTIONAL)); bool optional_attr = (a && a->Find(ATTR_OPTIONAL));
auto rv_v = rv ? rv->GetField(i) : nullptr; auto rv_v = rv ? rv->GetField(i) : nullptr;
if ( ! ReserveSingleTypeKeySize(hk, rt->GetFieldType(i).get(), rv_v.get(), if ( ! ReserveSingleTypeKeySize(hk, rt->GetFieldType(i).get(), rv_v.get(), type_check,
type_check, optional_attr, calc_static_size, optional_attr, calc_static_size, false) )
false) )
return false; return false;
} }
break; break;
} }
case TYPE_TABLE: case TYPE_TABLE: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
@ -866,21 +716,17 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
hk.ReserveType<int>("table-size"); hk.ReserveType<int>("table-size");
for ( auto& kv : *hashkeys ) for ( auto& kv : *hashkeys ) {
{
auto key = kv.second; auto key = kv.second;
if ( ! ReserveSingleTypeKeySize(hk, key->GetType().get(), key.get(), if ( ! ReserveSingleTypeKeySize(hk, key->GetType().get(), key.get(), type_check, false,
type_check, false, calc_static_size, calc_static_size, false) )
false) )
return false; return false;
if ( ! bt->IsSet() ) if ( ! bt->IsSet() ) {
{
auto val = const_cast<TableVal*>(tv)->FindOrDefault(key); auto val = const_cast<TableVal*>(tv)->FindOrDefault(key);
if ( ! ReserveSingleTypeKeySize(hk, val->GetType().get(), val.get(), if ( ! ReserveSingleTypeKeySize(hk, val->GetType().get(), val.get(), type_check, false,
type_check, false, calc_static_size, calc_static_size, false) )
false) )
return false; return false;
} }
} }
@ -888,48 +734,40 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
break; break;
} }
case TYPE_VECTOR: case TYPE_VECTOR: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
hk.ReserveType<int>("vector-size"); hk.ReserveType<int>("vector-size");
VectorVal* vv = const_cast<VectorVal*>(v->AsVectorVal()); VectorVal* vv = const_cast<VectorVal*>(v->AsVectorVal());
for ( unsigned int i = 0; i < vv->Size(); ++i ) for ( unsigned int i = 0; i < vv->Size(); ++i ) {
{
auto val = vv->ValAt(i); auto val = vv->ValAt(i);
hk.ReserveType<unsigned int>("vector-idx"); hk.ReserveType<unsigned int>("vector-idx");
hk.ReserveType<unsigned int>("vector-idx-present"); hk.ReserveType<unsigned int>("vector-idx-present");
if ( val && ! ReserveSingleTypeKeySize( if ( val && ! ReserveSingleTypeKeySize(hk, bt->AsVectorType()->Yield().get(), val.get(),
hk, bt->AsVectorType()->Yield().get(), val.get(),
type_check, false, calc_static_size, false) ) type_check, false, calc_static_size, false) )
return false; return false;
} }
break; break;
} }
case TYPE_LIST: case TYPE_LIST: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
hk.ReserveType<int>("list-size"); hk.ReserveType<int>("list-size");
ListVal* lv = const_cast<ListVal*>(v->AsListVal()); ListVal* lv = const_cast<ListVal*>(v->AsListVal());
for ( int i = 0; i < lv->Length(); ++i ) for ( int i = 0; i < lv->Length(); ++i ) {
{ if ( ! ReserveSingleTypeKeySize(hk, lv->Idx(i)->GetType().get(), lv->Idx(i).get(), type_check,
if ( ! ReserveSingleTypeKeySize(hk, lv->Idx(i)->GetType().get(), false, calc_static_size, false) )
lv->Idx(i).get(), type_check, false,
calc_static_size, false) )
return false; return false;
} }
break; break;
} }
default: default: {
{ reporter->InternalError("bad index type in CompositeHash::ReserveSingleTypeKeySize");
reporter->InternalError(
"bad index type in CompositeHash::ReserveSingleTypeKeySize");
return false; return false;
} }
} }
@ -945,11 +783,10 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
hk.Reserve("string", v->AsString()->Len()); hk.Reserve("string", v->AsString()->Len());
break; break;
case TYPE_INTERNAL_ERROR: case TYPE_INTERNAL_ERROR: return false;
return false;
} }
return true; return true;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,21 +7,18 @@
#include "zeek/Func.h" #include "zeek/Func.h"
#include "zeek/Type.h" #include "zeek/Type.h"
namespace zeek namespace zeek {
{
class ListVal; class ListVal;
using ListValPtr = zeek::IntrusivePtr<ListVal>; using ListValPtr = zeek::IntrusivePtr<ListVal>;
} // namespace zeek } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class HashKey; class HashKey;
class CompositeHash class CompositeHash {
{
public: public:
explicit CompositeHash(TypeListPtr composite_type); explicit CompositeHash(TypeListPtr composite_type);
@ -33,15 +30,13 @@ public:
ListValPtr RecoverVals(const HashKey& k) const; ListValPtr RecoverVals(const HashKey& k) const;
protected: protected:
bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, bool singleton) const;
bool singleton) const;
// Recovers just one Val of possibly many; called from RecoverVals. // Recovers just one Val of possibly many; called from RecoverVals.
// Upon return, pval will point to the recovered Val of type t. // Upon return, pval will point to the recovered Val of type t.
// Returns and updated kp for the next Val. Calls reporter->InternalError() // Returns and updated kp for the next Val. Calls reporter->InternalError()
// upon errors, so there is no return value for invalid input. // upon errors, so there is no return value for invalid input.
bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, bool singleton) const;
bool singleton) const;
// Compute the size of the composite key. If v is non-nil then // Compute the size of the composite key. If v is non-nil then
// the value is computed for the particular list of values. // the value is computed for the particular list of values.
@ -60,14 +55,13 @@ protected:
// lower for the common case of these not being needed. // lower for the common case of these not being needed.
std::unique_ptr<std::unordered_map<const Func*, uint32_t>> func_to_func_id; std::unique_ptr<std::unordered_map<const Func*, uint32_t>> func_to_func_id;
std::unique_ptr<std::vector<FuncPtr>> func_id_to_func; std::unique_ptr<std::vector<FuncPtr>> func_id_to_func;
void BuildFuncMappings() void BuildFuncMappings() {
{
func_to_func_id = std::make_unique<std::unordered_map<const Func*, uint32_t>>(); func_to_func_id = std::make_unique<std::unordered_map<const Func*, uint32_t>>();
func_id_to_func = std::make_unique<std::vector<FuncPtr>>(); func_id_to_func = std::make_unique<std::vector<FuncPtr>>();
} }
TypeListPtr type; TypeListPtr type;
bool is_singleton = false; // if just one type in index bool is_singleton = false; // if just one type in index
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -22,18 +22,13 @@
#include "zeek/packet_analysis/protocol/tcp/TCP.h" #include "zeek/packet_analysis/protocol/tcp/TCP.h"
#include "zeek/session/Manager.h" #include "zeek/session/Manager.h"
namespace zeek namespace zeek {
{
uint64_t Connection::total_connections = 0; uint64_t Connection::total_connections = 0;
uint64_t Connection::current_connections = 0; uint64_t Connection::current_connections = 0;
Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt)
const Packet* pkt) : Session(t, connection_timeout, connection_status_update, detail::connection_status_update_interval), key(k) {
: Session(t, connection_timeout, connection_status_update,
detail::connection_status_update_interval),
key(k)
{
orig_addr = id->src_addr; orig_addr = id->src_addr;
resp_addr = id->dst_addr; resp_addr = id->dst_addr;
orig_port = id->src_port; orig_port = id->src_port;
@ -73,10 +68,9 @@ Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id,
++total_connections; ++total_connections;
encapsulation = pkt->encap; encapsulation = pkt->encap;
} }
Connection::~Connection() Connection::~Connection() {
{
if ( ! finished ) if ( ! finished )
reporter->InternalError("Done() not called before destruction of Connection"); reporter->InternalError("Done() not called before destruction of Connection");
@ -88,18 +82,13 @@ Connection::~Connection()
delete adapter; delete adapter;
--current_connections; --current_connections;
} }
void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& arg_encap) void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& arg_encap) {
{ if ( encapsulation && arg_encap ) {
if ( encapsulation && arg_encap ) if ( *encapsulation != *arg_encap ) {
{ if ( tunnel_changed && (zeek::detail::tunnel_max_changes_per_connection == 0 ||
if ( *encapsulation != *arg_encap ) tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) ) {
{
if ( tunnel_changed &&
(zeek::detail::tunnel_max_changes_per_connection == 0 ||
tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) )
{
tunnel_changes++; tunnel_changes++;
EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
} }
@ -108,10 +97,8 @@ void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& a
} }
} }
else if ( encapsulation ) else if ( encapsulation ) {
{ if ( tunnel_changed ) {
if ( tunnel_changed )
{
EncapsulationStack empty; EncapsulationStack empty;
EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal()); EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal());
} }
@ -119,23 +106,19 @@ void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& a
encapsulation = nullptr; encapsulation = nullptr;
} }
else if ( arg_encap ) else if ( arg_encap ) {
{
if ( tunnel_changed ) if ( tunnel_changed )
EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
encapsulation = std::make_shared<EncapsulationStack>(*arg_encap); encapsulation = std::make_shared<EncapsulationStack>(*arg_encap);
} }
} }
void Connection::Done() void Connection::Done() {
{
finished = 1; finished = 1;
if ( adapter ) if ( adapter ) {
{ if ( ConnTransport() == TRANSPORT_TCP ) {
if ( ConnTransport() == TRANSPORT_TCP )
{
auto* ta = static_cast<packet_analysis::TCP::TCPSessionAdapter*>(adapter); auto* ta = static_cast<packet_analysis::TCP::TCPSessionAdapter*>(adapter);
assert(ta->IsAnalyzer("TCP")); assert(ta->IsAnalyzer("TCP"));
analyzer::tcp::TCP_Endpoint* to = ta->Orig(); analyzer::tcp::TCP_Endpoint* to = ta->Orig();
@ -147,18 +130,16 @@ void Connection::Done()
if ( ! adapter->IsFinished() ) if ( ! adapter->IsFinished() )
adapter->Done(); adapter->Done();
} }
} }
void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data,
const u_char*& data, int& record_packet, int& record_content, int& record_packet, int& record_content,
// arguments for reproducing packets // arguments for reproducing packets
const Packet* pkt) const Packet* pkt) {
{
run_state::current_timestamp = t; run_state::current_timestamp = t;
run_state::current_pkt = pkt; run_state::current_pkt = pkt;
if ( adapter ) if ( adapter ) {
{
if ( adapter->Skipping() ) if ( adapter->Skipping() )
return; return;
@ -173,18 +154,12 @@ void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, i
run_state::current_timestamp = 0; run_state::current_timestamp = 0;
run_state::current_pkt = nullptr; run_state::current_pkt = nullptr;
} }
bool Connection::IsReuse(double t, const u_char* pkt) bool Connection::IsReuse(double t, const u_char* pkt) { return adapter && adapter->IsReuse(t, pkt); }
{
return adapter && adapter->IsReuse(t, pkt);
}
bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base) {
uint32_t scaling_base) if ( ++counter == scaling_threshold ) {
{
if ( ++counter == scaling_threshold )
{
AddHistory(code); AddHistory(code);
auto new_threshold = scaling_threshold * scaling_base; auto new_threshold = scaling_threshold * scaling_base;
@ -202,10 +177,9 @@ bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scal
} }
return false; return false;
} }
void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) {
{
if ( ! e ) if ( ! e )
return; return;
@ -215,15 +189,13 @@ void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t
return; return;
EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold)); EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold));
} }
namespace namespace {
{
// Flip everything that needs to be flipped in the connection // Flip everything that needs to be flipped in the connection
// record that is known on this level. This needs to align // record that is known on this level. This needs to align
// with GetVal() and connection's layout in init-bare. // with GetVal() and connection's layout in init-bare.
void flip_conn_val(const RecordValPtr& conn_val) void flip_conn_val(const RecordValPtr& conn_val) {
{
// Flip the the conn_id (c$id). // Flip the the conn_id (c$id).
const auto& id_val = conn_val->GetField<zeek::RecordVal>(0); const auto& id_val = conn_val->GetField<zeek::RecordVal>(0);
const auto& tmp_addr = id_val->GetField<zeek::AddrVal>(0); const auto& tmp_addr = id_val->GetField<zeek::AddrVal>(0);
@ -237,13 +209,11 @@ void flip_conn_val(const RecordValPtr& conn_val)
const auto& tmp_endp = conn_val->GetField<zeek::RecordVal>(1); const auto& tmp_endp = conn_val->GetField<zeek::RecordVal>(1);
conn_val->Assign(1, conn_val->GetField(2)); conn_val->Assign(1, conn_val->GetField(2));
conn_val->Assign(2, tmp_endp); conn_val->Assign(2, tmp_endp);
} }
} } // namespace
const RecordValPtr& Connection::GetVal() const RecordValPtr& Connection::GetVal() {
{ if ( ! conn_val ) {
if ( ! conn_val )
{
conn_val = make_intrusive<RecordVal>(id::connection); conn_val = make_intrusive<RecordVal>(id::connection);
TransportProto prot_type = ConnTransport(); TransportProto prot_type = ConnTransport();
@ -301,8 +271,7 @@ const RecordValPtr& Connection::GetVal()
conn_val->AssignTime(3, start_time); // ### conn_val->AssignTime(3, start_time); // ###
conn_val->AssignInterval(4, last_time - start_time); conn_val->AssignInterval(4, last_time - start_time);
if ( ! history.empty() ) if ( ! history.empty() ) {
{
auto v = conn_val->GetFieldAs<StringVal>(6); auto v = conn_val->GetFieldAs<StringVal>(6);
if ( *v != history ) if ( *v != history )
conn_val->Assign(6, history); conn_val->Assign(6, history);
@ -311,54 +280,42 @@ const RecordValPtr& Connection::GetVal()
conn_val->SetOrigin(this); conn_val->SetOrigin(this);
return conn_val; return conn_val;
} }
analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) { return adapter ? adapter->FindChild(id) : nullptr; }
{
return adapter ? adapter->FindChild(id) : nullptr;
}
analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) {
{
return adapter ? adapter->FindChild(tag) : nullptr; return adapter ? adapter->FindChild(tag) : nullptr;
} }
analyzer::Analyzer* Connection::FindAnalyzer(const char* name) analyzer::Analyzer* Connection::FindAnalyzer(const char* name) { return adapter->FindChild(name); }
{
return adapter->FindChild(name);
}
void Connection::AppendAddl(const char* str) void Connection::AppendAddl(const char* str) {
{
const auto& cv = GetVal(); const auto& cv = GetVal();
const char* old = cv->GetFieldAs<StringVal>(6)->CheckString(); const char* old = cv->GetFieldAs<StringVal>(6)->CheckString();
const char* format = *old ? "%s %s" : "%s%s"; const char* format = *old ? "%s %s" : "%s%s";
cv->Assign(6, util::fmt(format, old, str)); cv->Assign(6, util::fmt(format, old, str));
} }
void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol,
bool bol, bool eol, bool clear_state) bool clear_state) {
{
if ( primary_PIA ) if ( primary_PIA )
primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state); primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state);
} }
void Connection::RemovalEvent() void Connection::RemovalEvent() {
{
if ( connection_state_remove ) if ( connection_state_remove )
EnqueueEvent(connection_state_remove, nullptr, GetVal()); EnqueueEvent(connection_state_remove, nullptr, GetVal());
} }
void Connection::Weird(const char* name, const char* addl, const char* source) void Connection::Weird(const char* name, const char* addl, const char* source) {
{
weird = 1; weird = 1;
reporter->Weird(this, name, addl ? addl : "", source ? source : ""); reporter->Weird(this, name, addl ? addl : "", source ? source : "");
} }
void Connection::FlipRoles() void Connection::FlipRoles() {
{
IPAddr tmp_addr = resp_addr; IPAddr tmp_addr = resp_addr;
resp_addr = orig_addr; resp_addr = orig_addr;
orig_addr = tmp_addr; orig_addr = tmp_addr;
@ -393,25 +350,17 @@ void Connection::FlipRoles()
if ( connection_flipped ) if ( connection_flipped )
EnqueueEvent(connection_flipped, nullptr, GetVal()); EnqueueEvent(connection_flipped, nullptr, GetVal());
} }
void Connection::Describe(ODesc* d) const void Connection::Describe(ODesc* d) const {
{
session::Session::Describe(d); session::Session::Describe(d);
switch ( proto ) switch ( proto ) {
{ case TRANSPORT_TCP: d->Add("TCP"); break;
case TRANSPORT_TCP:
d->Add("TCP");
break;
case TRANSPORT_UDP: case TRANSPORT_UDP: d->Add("UDP"); break;
d->Add("UDP");
break;
case TRANSPORT_ICMP: case TRANSPORT_ICMP: d->Add("ICMP"); break;
d->Add("ICMP");
break;
case TRANSPORT_UNKNOWN: case TRANSPORT_UNKNOWN:
d->Add("unknown"); d->Add("unknown");
@ -419,8 +368,7 @@ void Connection::Describe(ODesc* d) const
break; break;
default: default: reporter->InternalError("unhandled transport type in Connection::Describe");
reporter->InternalError("unhandled transport type in Connection::Describe");
} }
d->SP(); d->SP();
@ -436,10 +384,9 @@ void Connection::Describe(ODesc* d) const
d->Add(ntohs(resp_port)); d->Add(ntohs(resp_port));
d->NL(); d->NL();
} }
void Connection::IDString(ODesc* d) const void Connection::IDString(ODesc* d) const {
{
d->Add(orig_addr); d->Add(orig_addr);
d->AddRaw(":", 1); d->AddRaw(":", 1);
d->Add(ntohs(orig_port)); d->Add(ntohs(orig_port));
@ -447,29 +394,23 @@ void Connection::IDString(ODesc* d) const
d->Add(resp_addr); d->Add(resp_addr);
d->AddRaw(":", 1); d->AddRaw(":", 1);
d->Add(ntohs(resp_port)); d->Add(ntohs(resp_port));
} }
void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) {
{
adapter = aa; adapter = aa;
primary_PIA = pia; primary_PIA = pia;
} }
void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) {
{
uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label; uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label;
if ( my_flow_label != flow_label ) if ( my_flow_label != flow_label ) {
{ if ( conn_val ) {
if ( conn_val )
{
RecordVal* endp = conn_val->GetFieldAs<RecordVal>(is_orig ? 1 : 2); RecordVal* endp = conn_val->GetFieldAs<RecordVal>(is_orig ? 1 : 2);
endp->Assign(4, flow_label); endp->Assign(4, flow_label);
} }
if ( connection_flow_label_changed && if ( connection_flow_label_changed && (is_orig ? saw_first_orig_packet : saw_first_resp_packet) ) {
(is_orig ? saw_first_orig_packet : saw_first_resp_packet) )
{
EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig), EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig),
val_mgr->Count(my_flow_label), val_mgr->Count(flow_label)); val_mgr->Count(my_flow_label), val_mgr->Count(flow_label));
} }
@ -481,11 +422,10 @@ void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label)
saw_first_orig_packet = 1; saw_first_orig_packet = 1;
else else
saw_first_resp_packet = 1; saw_first_resp_packet = 1;
} }
bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) {
{
return detail::PermitWeird(weird_state, name, threshold, rate, duration); return detail::PermitWeird(weird_state, name, threshold, rate, duration);
} }
} // namespace zeek } // namespace zeek

View file

@ -19,8 +19,7 @@
#include "zeek/iosource/Packet.h" #include "zeek/iosource/Packet.h"
#include "zeek/session/Session.h" #include "zeek/session/Session.h"
namespace zeek namespace zeek {
{
class Connection; class Connection;
class EncapsulationStack; class EncapsulationStack;
@ -30,57 +29,47 @@ class RecordVal;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using RecordValPtr = IntrusivePtr<RecordVal>; using RecordValPtr = IntrusivePtr<RecordVal>;
namespace session namespace session {
{
class Manager; class Manager;
} }
namespace detail namespace detail {
{
class Specific_RE_Matcher; class Specific_RE_Matcher;
class RuleEndpointState; class RuleEndpointState;
class RuleHdrTest; class RuleHdrTest;
} // namespace detail } // namespace detail
namespace analyzer namespace analyzer {
{
class Analyzer; class Analyzer;
} }
namespace packet_analysis::IP namespace packet_analysis::IP {
{
class SessionAdapter; class SessionAdapter;
} }
enum ConnEventToFlag enum ConnEventToFlag {
{
NUL_IN_LINE, NUL_IN_LINE,
SINGULAR_CR, SINGULAR_CR,
SINGULAR_LF, SINGULAR_LF,
NUM_EVENTS_TO_FLAG, NUM_EVENTS_TO_FLAG,
}; };
struct ConnTuple struct ConnTuple {
{
IPAddr src_addr; IPAddr src_addr;
IPAddr dst_addr; IPAddr dst_addr;
uint32_t src_port = 0; uint32_t src_port = 0;
uint32_t dst_port = 0; uint32_t dst_port = 0;
bool is_one_way = false; // if true, don't canonicalize order bool is_one_way = false; // if true, don't canonicalize order
TransportProto proto = TRANSPORT_UNKNOWN; TransportProto proto = TRANSPORT_UNKNOWN;
}; };
static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, uint32_t p2) {
uint32_t p2)
{
return addr1 < addr2 || (addr1 == addr2 && p1 < p2); return addr1 < addr2 || (addr1 == addr2 && p1 < p2);
} }
class Connection final : public session::Session class Connection final : public session::Session {
{
public: public:
Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt);
const Packet* pkt);
~Connection() override; ~Connection() override;
/** /**
@ -109,8 +98,8 @@ public:
// If record_content is true, then its entire contents should // If record_content is true, then its entire contents should
// be recorded, otherwise just up through the transport header. // be recorded, otherwise just up through the transport header.
// Both are assumed set to true when called. // Both are assumed set to true when called.
void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data,
const u_char*& data, int& record_packet, int& record_content, int& record_packet, int& record_content,
// arguments for reproducing packets // arguments for reproducing packets
const Packet* pkt); const Packet* pkt);
@ -118,10 +107,8 @@ public:
// connection is in the session map. If it is removed, the key // connection is in the session map. If it is removed, the key
// should be marked invalid. // should be marked invalid.
const detail::ConnKey& Key() const { return key; } const detail::ConnKey& Key() const { return key; }
session::detail::Key SessionKey(bool copy) const override session::detail::Key SessionKey(bool copy) const override {
{ return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE, copy};
return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE,
copy};
} }
const IPAddr& OrigAddr() const { return orig_addr; } const IPAddr& OrigAddr() const { return orig_addr; }
@ -137,8 +124,7 @@ public:
analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree. analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree.
TransportProto ConnTransport() const { return proto; } TransportProto ConnTransport() const { return proto; }
std::string TransportIdentifier() const override std::string TransportIdentifier() const override {
{
if ( proto == TRANSPORT_TCP ) if ( proto == TRANSPORT_TCP )
return "tcp"; return "tcp";
else if ( proto == TRANSPORT_UDP ) else if ( proto == TRANSPORT_UDP )
@ -164,8 +150,8 @@ public:
*/ */
void AppendAddl(const char* str); void AppendAddl(const char* str);
void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol,
bool eol, bool clear_state); bool clear_state);
/** /**
* Generates connection removal event(s). * Generates connection removal event(s).
@ -175,10 +161,8 @@ public:
void Weird(const char* name, const char* addl = "", const char* source = ""); void Weird(const char* name, const char* addl = "", const char* source = "");
bool DidWeird() const { return weird != 0; } bool DidWeird() const { return weird != 0; }
inline bool FlagEvent(ConnEventToFlag e) inline bool FlagEvent(ConnEventToFlag e) {
{ if ( e >= 0 && e < NUM_EVENTS_TO_FLAG ) {
if ( e >= 0 && e < NUM_EVENTS_TO_FLAG )
{
if ( suppress_event & (1 << e) ) if ( suppress_event & (1 << e) )
return false; return false;
suppress_event |= 1 << e; suppress_event |= 1 << e;
@ -196,10 +180,8 @@ public:
static uint64_t CurrentConnections() { return current_connections; } static uint64_t CurrentConnections() { return current_connections; }
// Returns true if the history was already seen, false otherwise. // Returns true if the history was already seen, false otherwise.
bool CheckHistory(uint32_t mask, char code) bool CheckHistory(uint32_t mask, char code) {
{ if ( (hist_seen & mask) == 0 ) {
if ( (hist_seen & mask) == 0 )
{
hist_seen |= mask; hist_seen |= mask;
AddHistory(code); AddHistory(code);
return false; return false;
@ -212,8 +194,7 @@ public:
// code if it has crossed the next scaling threshold. Scaling // code if it has crossed the next scaling threshold. Scaling
// is done in terms of powers of the third argument. // is done in terms of powers of the third argument.
// Returns true if the threshold was crossed, false otherwise. // Returns true if the threshold was crossed, false otherwise.
bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base = 10);
uint32_t scaling_base = 10);
void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold); void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold);
@ -277,6 +258,6 @@ private:
// Count number of connections. // Count number of connections.
static uint64_t total_connections; static uint64_t total_connections;
static uint64_t current_connections; static uint64_t current_connections;
}; };
} // namespace zeek } // namespace zeek

View file

@ -8,12 +8,10 @@
#include "zeek/EquivClass.h" #include "zeek/EquivClass.h"
#include "zeek/Hash.h" #include "zeek/Hash.h"
namespace zeek::detail namespace zeek::detail {
{
DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* arg_nfa_states, DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* arg_nfa_states,
AcceptingSet* arg_accept) AcceptingSet* arg_accept) {
{
state_num = arg_state_num; state_num = arg_state_num;
num_sym = ec->NumClasses(); num_sym = ec->NumClasses();
nfa_states = arg_nfa_states; nfa_states = arg_nfa_states;
@ -26,41 +24,33 @@ DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* ar
for ( int i = 0; i < num_sym; ++i ) for ( int i = 0; i < num_sym; ++i )
xtions[i] = DFA_UNCOMPUTED_STATE_PTR; xtions[i] = DFA_UNCOMPUTED_STATE_PTR;
} }
DFA_State::~DFA_State() DFA_State::~DFA_State() {
{
delete[] xtions; delete[] xtions;
delete nfa_states; delete nfa_states;
delete accept; delete accept;
delete meta_ec; delete meta_ec;
} }
void DFA_State::AddXtion(int sym, DFA_State* next_state) void DFA_State::AddXtion(int sym, DFA_State* next_state) { xtions[sym] = next_state; }
{
xtions[sym] = next_state;
}
void DFA_State::SymPartition(const EquivClass* ec) void DFA_State::SymPartition(const EquivClass* ec) {
{
// Partitioning is done by creating equivalence classes for those // Partitioning is done by creating equivalence classes for those
// characters which have out-transitions from the given state. Thus // characters which have out-transitions from the given state. Thus
// we are really creating equivalence classes of equivalence classes. // we are really creating equivalence classes of equivalence classes.
meta_ec = new EquivClass(ec->NumClasses()); meta_ec = new EquivClass(ec->NumClasses());
assert(nfa_states); assert(nfa_states);
for ( int i = 0; i < nfa_states->length(); ++i ) for ( int i = 0; i < nfa_states->length(); ++i ) {
{
NFA_State* n = (*nfa_states)[i]; NFA_State* n = (*nfa_states)[i];
int sym = n->TransSym(); int sym = n->TransSym();
if ( sym == SYM_EPSILON ) if ( sym == SYM_EPSILON )
continue; continue;
if ( sym != SYM_CCL ) if ( sym != SYM_CCL ) { // character transition
{ // character transition if ( ec->IsRep(sym) ) {
if ( ec->IsRep(sym) )
{
sym = ec->SymEquivClass(sym); sym = ec->SymEquivClass(sym);
meta_ec->UniqueChar(sym); meta_ec->UniqueChar(sym);
} }
@ -72,13 +62,11 @@ void DFA_State::SymPartition(const EquivClass* ec)
} }
meta_ec->BuildECs(); meta_ec->BuildECs();
} }
DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) {
{
int equiv_sym = meta_ec->EquivRep(sym); int equiv_sym = meta_ec->EquivRep(sym);
if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) {
{
AddXtion(sym, xtions[equiv_sym]); AddXtion(sym, xtions[equiv_sym]);
return xtions[sym]; return xtions[sym];
} }
@ -88,14 +76,12 @@ DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine)
DFA_State* next_d; DFA_State* next_d;
NFA_state_list* ns = SymFollowSet(equiv_sym, ec); NFA_state_list* ns = SymFollowSet(equiv_sym, ec);
if ( ns->length() > 0 ) if ( ns->length() > 0 ) {
{
NFA_state_list* state_set = epsilon_closure(ns); NFA_state_list* state_set = epsilon_closure(ns);
if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) ) if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) )
delete state_set; delete state_set;
} }
else else {
{
delete ns; delete ns;
next_d = nullptr; // Jam next_d = nullptr; // Jam
} }
@ -105,37 +91,31 @@ DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine)
AddXtion(sym, next_d); AddXtion(sym, next_d);
return xtions[sym]; return xtions[sym];
} }
void DFA_State::AppendIfNew(int sym, int_list* sym_list) void DFA_State::AppendIfNew(int sym, int_list* sym_list) {
{
for ( auto value : *sym_list ) for ( auto value : *sym_list )
if ( value == sym ) if ( value == sym )
return; return;
sym_list->push_back(sym); sym_list->push_back(sym);
} }
NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) {
{
NFA_state_list* ns = new NFA_state_list; NFA_state_list* ns = new NFA_state_list;
assert(nfa_states); assert(nfa_states);
for ( int i = 0; i < nfa_states->length(); ++i ) for ( int i = 0; i < nfa_states->length(); ++i ) {
{
NFA_State* n = (*nfa_states)[i]; NFA_State* n = (*nfa_states)[i];
if ( n->TransSym() == SYM_CCL ) if ( n->TransSym() == SYM_CCL ) { // it's a character class
{ // it's a character class
CCL* ccl = n->TransCCL(); CCL* ccl = n->TransCCL();
int_list* syms = ccl->Syms(); int_list* syms = ccl->Syms();
if ( ccl->IsNegated() ) if ( ccl->IsNegated() ) {
{
size_t j; size_t j;
for ( j = 0; j < syms->size(); ++j ) for ( j = 0; j < syms->size(); ++j ) {
{
// Loop through (sorted) negated // Loop through (sorted) negated
// character class, which has // character class, which has
// presumably already been converted // presumably already been converted
@ -151,25 +131,21 @@ NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec)
continue; continue;
} }
for ( auto sym : *syms ) for ( auto sym : *syms ) {
{
if ( sym > ec_sym ) if ( sym > ec_sym )
break; break;
if ( sym == ec_sym ) if ( sym == ec_sym ) {
{
n->AddXtionsTo(ns); n->AddXtionsTo(ns);
break; break;
} }
} }
} }
else if ( n->TransSym() == SYM_EPSILON ) else if ( n->TransSym() == SYM_EPSILON ) { // do nothing
{ // do nothing
} }
else if ( ec->IsRep(n->TransSym()) ) else if ( ec->IsRep(n->TransSym()) ) {
{
if ( ec_sym == ec->SymEquivClass(n->TransSym()) ) if ( ec_sym == ec->SymEquivClass(n->TransSym()) )
n->AddXtionsTo(ns); n->AddXtionsTo(ns);
} }
@ -177,38 +153,30 @@ NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec)
ns->resize(0); ns->resize(0);
return ns; return ns;
} }
void DFA_State::ClearMarks() void DFA_State::ClearMarks() {
{ if ( mark ) {
if ( mark )
{
SetMark(nullptr); SetMark(nullptr);
for ( int i = 0; i < num_sym; ++i ) for ( int i = 0; i < num_sym; ++i ) {
{
DFA_State* s = xtions[i]; DFA_State* s = xtions[i];
if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) if ( s && s != DFA_UNCOMPUTED_STATE_PTR )
xtions[i]->ClearMarks(); xtions[i]->ClearMarks();
} }
} }
} }
void DFA_State::Describe(ODesc* d) const void DFA_State::Describe(ODesc* d) const { d->Add("DFA state"); }
{
d->Add("DFA state");
}
void DFA_State::Dump(FILE* f, DFA_Machine* m) void DFA_State::Dump(FILE* f, DFA_Machine* m) {
{
if ( mark ) if ( mark )
return; return;
fprintf(f, "\nDFA state %d:", StateNum()); fprintf(f, "\nDFA state %d:", StateNum());
if ( accept ) if ( accept ) {
{
AcceptingSet::const_iterator it; AcceptingSet::const_iterator it;
for ( it = accept->begin(); it != accept->end(); ++it ) for ( it = accept->begin(); it != accept->end(); ++it )
@ -218,8 +186,7 @@ void DFA_State::Dump(FILE* f, DFA_Machine* m)
fprintf(f, "\n"); fprintf(f, "\n");
int num_trans = 0; int num_trans = 0;
for ( int sym = 0; sym < num_sym; ++sym ) for ( int sym = 0; sym < num_sym; ++sym ) {
{
DFA_State* s = xtions[sym]; DFA_State* s = xtions[sym];
if ( ! s ) if ( ! s )
@ -244,11 +211,9 @@ void DFA_State::Dump(FILE* f, DFA_Machine* m)
snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1)); snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1));
if ( s == DFA_UNCOMPUTED_STATE_PTR ) if ( s == DFA_UNCOMPUTED_STATE_PTR )
fprintf(f, "%stransition on %s to <uncomputed>", ++num_trans == 1 ? "\t" : "\n\t", fprintf(f, "%stransition on %s to <uncomputed>", ++num_trans == 1 ? "\t" : "\n\t", xbuf);
xbuf);
else else
fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, s->StateNum());
s->StateNum());
delete[] xbuf; delete[] xbuf;
@ -260,19 +225,16 @@ void DFA_State::Dump(FILE* f, DFA_Machine* m)
SetMark(this); SetMark(this);
for ( int sym = 0; sym < num_sym; ++sym ) for ( int sym = 0; sym < num_sym; ++sym ) {
{
DFA_State* s = xtions[sym]; DFA_State* s = xtions[sym];
if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) if ( s && s != DFA_UNCOMPUTED_STATE_PTR )
s->Dump(f, m); s->Dump(f, m);
} }
} }
void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed) void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed) {
{ for ( int sym = 0; sym < num_sym; ++sym ) {
for ( int sym = 0; sym < num_sym; ++sym )
{
DFA_State* s = xtions[sym]; DFA_State* s = xtions[sym];
if ( s == DFA_UNCOMPUTED_STATE_PTR ) if ( s == DFA_UNCOMPUTED_STATE_PTR )
@ -280,48 +242,38 @@ void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed)
else else
(*computed)++; (*computed)++;
} }
} }
unsigned int DFA_State::Size() unsigned int DFA_State::Size() {
{
return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) + return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) +
(accept ? util::pad_size(sizeof(int) * accept->size()) : 0) + (accept ? util::pad_size(sizeof(int) * accept->size()) : 0) +
(nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) + (nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) +
(meta_ec ? meta_ec->Size() : 0); (meta_ec ? meta_ec->Size() : 0);
} }
DFA_State_Cache::DFA_State_Cache() DFA_State_Cache::DFA_State_Cache() { hits = misses = 0; }
{
hits = misses = 0;
}
DFA_State_Cache::~DFA_State_Cache() DFA_State_Cache::~DFA_State_Cache() {
{ for ( auto& entry : states ) {
for ( auto& entry : states )
{
assert(entry.second); assert(entry.second);
Unref(entry.second); Unref(entry.second);
} }
states.clear(); states.clear();
} }
DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest) DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest) {
{
// We assume that state ID's don't exceed 10 digits, plus // We assume that state ID's don't exceed 10 digits, plus
// we allow one more character for the delimiter. // we allow one more character for the delimiter.
auto id_tag_buf = std::make_unique<u_char[]>(nfas.length() * 11 + 1); auto id_tag_buf = std::make_unique<u_char[]>(nfas.length() * 11 + 1);
auto id_tag = id_tag_buf.get(); auto id_tag = id_tag_buf.get();
u_char* p = id_tag; u_char* p = id_tag;
for ( int i = 0; i < nfas.length(); ++i ) for ( int i = 0; i < nfas.length(); ++i ) {
{
NFA_State* n = nfas[i]; NFA_State* n = nfas[i];
if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) {
{
int id = n->ID(); int id = n->ID();
do do {
{
*p++ = '0' + (char)(id % 10); *p++ = '0' + (char)(id % 10);
id /= 10; id /= 10;
} while ( id > 0 ); } while ( id > 0 );
@ -338,8 +290,7 @@ DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest
*digest = DigestStr(reinterpret_cast<const unsigned char*>(hash), 16); *digest = DigestStr(reinterpret_cast<const unsigned char*>(hash), 16);
auto entry = states.find(*digest); auto entry = states.find(*digest);
if ( entry == states.end() ) if ( entry == states.end() ) {
{
++misses; ++misses;
return nullptr; return nullptr;
} }
@ -348,16 +299,14 @@ DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest
digest->clear(); digest->clear();
return entry->second; return entry->second;
} }
DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) {
{
states.emplace(std::move(digest), state); states.emplace(std::move(digest), state);
return state; return state;
} }
void DFA_State_Cache::GetStats(Stats* s) void DFA_State_Cache::GetStats(Stats* s) {
{
s->dfa_states = 0; s->dfa_states = 0;
s->nfa_states = 0; s->nfa_states = 0;
s->computed = 0; s->computed = 0;
@ -366,18 +315,16 @@ void DFA_State_Cache::GetStats(Stats* s)
s->hits = hits; s->hits = hits;
s->misses = misses; s->misses = misses;
for ( const auto& state : states ) for ( const auto& state : states ) {
{
DFA_State* e = state.second; DFA_State* e = state.second;
++s->dfa_states; ++s->dfa_states;
s->nfa_states += e->NFAStateNum(); s->nfa_states += e->NFAStateNum();
e->Stats(&s->computed, &s->uncomputed); e->Stats(&s->computed, &s->uncomputed);
s->mem += util::pad_size(e->Size()) + padded_sizeof(*e); s->mem += util::pad_size(e->Size()) + padded_sizeof(*e);
} }
} }
DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) {
{
state_count = 0; state_count = 0;
nfa = n; nfa = n;
@ -390,38 +337,29 @@ DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec)
NFA_state_list* ns = new NFA_state_list; NFA_state_list* ns = new NFA_state_list;
ns->push_back(n->FirstState()); ns->push_back(n->FirstState());
if ( ns->length() > 0 ) if ( ns->length() > 0 ) {
{
NFA_state_list* state_set = epsilon_closure(ns); NFA_state_list* state_set = epsilon_closure(ns);
StateSetToDFA_State(state_set, start_state, ec); StateSetToDFA_State(state_set, start_state, ec);
} }
else else {
{
start_state = nullptr; // Jam start_state = nullptr; // Jam
delete ns; delete ns;
} }
} }
DFA_Machine::~DFA_Machine() DFA_Machine::~DFA_Machine() {
{
delete dfa_state_cache; delete dfa_state_cache;
Unref(nfa); Unref(nfa);
} }
void DFA_Machine::Describe(ODesc* d) const void DFA_Machine::Describe(ODesc* d) const { d->Add("DFA machine"); }
{
d->Add("DFA machine");
}
void DFA_Machine::Dump(FILE* f) void DFA_Machine::Dump(FILE* f) {
{
start_state->Dump(f, this); start_state->Dump(f, this);
start_state->ClearMarks(); start_state->ClearMarks();
} }
bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec) {
const EquivClass* ec)
{
DigestStr digest; DigestStr digest;
d = dfa_state_cache->Lookup(*state_set, &digest); d = dfa_state_cache->Lookup(*state_set, &digest);
@ -430,16 +368,14 @@ bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d,
AcceptingSet* accept = new AcceptingSet; AcceptingSet* accept = new AcceptingSet;
for ( int i = 0; i < state_set->length(); ++i ) for ( int i = 0; i < state_set->length(); ++i ) {
{
int acc = (*state_set)[i]->Accept(); int acc = (*state_set)[i]->Accept();
if ( acc != NO_ACCEPT ) if ( acc != NO_ACCEPT )
accept->insert(acc); accept->insert(acc);
} }
if ( accept->empty() ) if ( accept->empty() ) {
{
delete accept; delete accept;
accept = nullptr; accept = nullptr;
} }
@ -448,15 +384,14 @@ bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d,
d = dfa_state_cache->Insert(ds, std::move(digest)); d = dfa_state_cache->Insert(ds, std::move(digest));
return true; return true;
} }
int DFA_Machine::Rep(int sym) int DFA_Machine::Rep(int sym) {
{
for ( int i = 0; i < NUM_SYM; ++i ) for ( int i = 0; i < NUM_SYM; ++i )
if ( ec->SymEquivClass(i) == sym ) if ( ec->SymEquivClass(i) == sym )
return i; return i;
return -1; return -1;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -11,8 +11,7 @@
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/RE.h" // for typedef AcceptingSet #include "zeek/RE.h" // for typedef AcceptingSet
namespace zeek::detail namespace zeek::detail {
{
class DFA_State; class DFA_State;
class DFA_Machine; class DFA_Machine;
@ -22,11 +21,9 @@ class DFA_Machine;
#define DFA_UNCOMPUTED_STATE -2 #define DFA_UNCOMPUTED_STATE -2
#define DFA_UNCOMPUTED_STATE_PTR ((DFA_State*)DFA_UNCOMPUTED_STATE) #define DFA_UNCOMPUTED_STATE_PTR ((DFA_State*)DFA_UNCOMPUTED_STATE)
class DFA_State : public Obj class DFA_State : public Obj {
{
public: public:
DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, AcceptingSet* accept);
AcceptingSet* accept);
~DFA_State() override; ~DFA_State() override;
int StateNum() const { return state_num; } int StateNum() const { return state_num; }
@ -68,12 +65,11 @@ protected:
NFA_state_list* nfa_states; NFA_state_list* nfa_states;
EquivClass* meta_ec; // which ec's make same transition EquivClass* meta_ec; // which ec's make same transition
DFA_State* mark; DFA_State* mark;
}; };
using DigestStr = std::basic_string<u_char>; using DigestStr = std::basic_string<u_char>;
class DFA_State_Cache class DFA_State_Cache {
{
public: public:
DFA_State_Cache(); DFA_State_Cache();
~DFA_State_Cache(); ~DFA_State_Cache();
@ -86,8 +82,7 @@ public:
int NumEntries() const { return states.size(); } int NumEntries() const { return states.size(); }
struct Stats struct Stats {
{
// Sum of all NFA states // Sum of all NFA states
unsigned int nfa_states; unsigned int nfa_states;
unsigned int dfa_states; unsigned int dfa_states;
@ -106,10 +101,9 @@ private:
// Hash indexed by NFA states (MD5s of them, actually). // Hash indexed by NFA states (MD5s of them, actually).
std::map<DigestStr, DFA_State*> states; std::map<DigestStr, DFA_State*> states;
}; };
class DFA_Machine : public Obj class DFA_Machine : public Obj {
{
public: public:
DFA_Machine(NFA_Machine* n, EquivClass* ec); DFA_Machine(NFA_Machine* n, EquivClass* ec);
~DFA_Machine() override; ~DFA_Machine() override;
@ -140,14 +134,13 @@ protected:
DFA_State_Cache* dfa_state_cache; DFA_State_Cache* dfa_state_cache;
NFA_Machine* nfa; NFA_Machine* nfa;
}; };
inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) {
{
if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR ) if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR )
return ComputeXtion(sym, machine); return ComputeXtion(sym, machine);
else else
return xtions[sym]; return xtions[sym];
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,11 +6,9 @@
#include "zeek/DNS_Mgr.h" #include "zeek/DNS_Mgr.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) {
{
Init(h); Init(h);
req_host = host; req_host = host;
req_ttl = ttl; req_ttl = ttl;
@ -18,18 +16,16 @@ DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int
if ( names.empty() ) if ( names.empty() )
names.push_back(std::move(host)); names.push_back(std::move(host));
} }
DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) {
{
Init(h); Init(h);
req_addr = addr; req_addr = addr;
req_ttl = ttl; req_ttl = ttl;
req_type = T_PTR; req_type = T_PTR;
} }
DNS_Mapping::DNS_Mapping(FILE* f) DNS_Mapping::DNS_Mapping(FILE* f) {
{
Clear(); Clear();
init_failed = true; init_failed = true;
@ -38,8 +34,7 @@ DNS_Mapping::DNS_Mapping(FILE* f)
char buf[512]; char buf[512];
if ( ! fgets(buf, sizeof(buf), f) ) if ( ! fgets(buf, sizeof(buf), f) ) {
{
no_mapping = true; no_mapping = true;
return; return;
} }
@ -49,9 +44,8 @@ DNS_Mapping::DNS_Mapping(FILE* f)
int failed_local; int failed_local;
int num_addrs; int num_addrs;
if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, &failed_local,
&failed_local, name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) {
{
no_mapping = true; no_mapping = true;
return; return;
} }
@ -65,8 +59,7 @@ DNS_Mapping::DNS_Mapping(FILE* f)
names.emplace_back(name_buf); names.emplace_back(name_buf);
for ( int i = 0; i < num_addrs; ++i ) for ( int i = 0; i < num_addrs; ++i ) {
{
if ( ! fgets(buf, sizeof(buf), f) ) if ( ! fgets(buf, sizeof(buf), f) )
return; return;
@ -78,15 +71,13 @@ DNS_Mapping::DNS_Mapping(FILE* f)
} }
init_failed = false; init_failed = false;
} }
ListValPtr DNS_Mapping::Addrs() ListValPtr DNS_Mapping::Addrs() {
{
if ( failed ) if ( failed )
return nullptr; return nullptr;
if ( ! addrs_val ) if ( ! addrs_val ) {
{
addrs_val = make_intrusive<ListVal>(TYPE_ADDR); addrs_val = make_intrusive<ListVal>(TYPE_ADDR);
for ( const auto& addr : addrs ) for ( const auto& addr : addrs )
@ -94,20 +85,18 @@ ListValPtr DNS_Mapping::Addrs()
} }
return addrs_val; return addrs_val;
} }
TableValPtr DNS_Mapping::AddrsSet() TableValPtr DNS_Mapping::AddrsSet() {
{
auto l = Addrs(); auto l = Addrs();
if ( ! l || l->Length() == 0 ) if ( ! l || l->Length() == 0 )
return DNS_Mgr::empty_addr_set(); return DNS_Mgr::empty_addr_set();
return l->ToSetVal(); return l->ToSetVal();
} }
StringValPtr DNS_Mapping::Host() StringValPtr DNS_Mapping::Host() {
{
if ( failed || names.empty() ) if ( failed || names.empty() )
return nullptr; return nullptr;
@ -115,18 +104,16 @@ StringValPtr DNS_Mapping::Host()
host_val = make_intrusive<StringVal>(names[0]); host_val = make_intrusive<StringVal>(names[0]);
return host_val; return host_val;
} }
void DNS_Mapping::Init(struct hostent* h) void DNS_Mapping::Init(struct hostent* h) {
{
no_mapping = false; no_mapping = false;
init_failed = false; init_failed = false;
creation_time = util::current_time(); creation_time = util::current_time();
host_val = nullptr; host_val = nullptr;
addrs_val = nullptr; addrs_val = nullptr;
if ( ! h ) if ( ! h ) {
{
Clear(); Clear();
return; return;
} }
@ -136,10 +123,8 @@ void DNS_Mapping::Init(struct hostent* h)
// TODO: this could easily be expanded to include all of the aliases as well // TODO: this could easily be expanded to include all of the aliases as well
names.emplace_back(h->h_name); names.emplace_back(h->h_name);
if ( h->h_addr_list ) if ( h->h_addr_list ) {
{ for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) {
for ( int i = 0; h->h_addr_list[i] != NULL; ++i )
{
if ( h->h_addrtype == AF_INET ) if ( h->h_addrtype == AF_INET )
addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network); addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network);
else if ( h->h_addrtype == AF_INET6 ) else if ( h->h_addrtype == AF_INET6 )
@ -148,10 +133,9 @@ void DNS_Mapping::Init(struct hostent* h)
} }
failed = false; failed = false;
} }
void DNS_Mapping::Clear() void DNS_Mapping::Clear() {
{
names.clear(); names.clear();
host_val = nullptr; host_val = nullptr;
addrs.clear(); addrs.clear();
@ -159,65 +143,56 @@ void DNS_Mapping::Clear()
no_mapping = false; no_mapping = false;
req_type = 0; req_type = 0;
failed = true; failed = true;
} }
void DNS_Mapping::Save(FILE* f) const void DNS_Mapping::Save(FILE* f) const {
{
fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(),
req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed,
names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl); names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl);
for ( const auto& addr : addrs ) for ( const auto& addr : addrs )
fprintf(f, "%s\n", addr.AsString().c_str()); fprintf(f, "%s\n", addr.AsString().c_str());
} }
void DNS_Mapping::Merge(const DNS_MappingPtr& other) void DNS_Mapping::Merge(const DNS_MappingPtr& other) {
{
std::copy(other->names.begin(), other->names.end(), std::back_inserter(names)); std::copy(other->names.begin(), other->names.end(), std::back_inserter(names));
std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs)); std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs));
} }
// This value needs to be incremented if something changes in the data stored by Save(). This // This value needs to be incremented if something changes in the data stored by Save(). This
// allows us to change the structure of the cache without breaking something in DNS_Mgr. // allows us to change the structure of the cache without breaking something in DNS_Mgr.
constexpr int FILE_VERSION = 1; constexpr int FILE_VERSION = 1;
void DNS_Mapping::InitializeCache(FILE* f) void DNS_Mapping::InitializeCache(FILE* f) { fprintf(f, "%d\n", FILE_VERSION); }
{
fprintf(f, "%d\n", FILE_VERSION);
}
bool DNS_Mapping::ValidateCacheVersion(FILE* f) bool DNS_Mapping::ValidateCacheVersion(FILE* f) {
{
char buf[512]; char buf[512];
if ( ! fgets(buf, sizeof(buf), f) ) if ( ! fgets(buf, sizeof(buf), f) )
return false; return false;
int version; int version;
if ( sscanf(buf, "%d", &version) != 1 ) if ( sscanf(buf, "%d", &version) != 1 ) {
{
reporter->Warning("Existing DNS cache did not have correct version, ignoring"); reporter->Warning("Existing DNS cache did not have correct version, ignoring");
return false; return false;
} }
return FILE_VERSION == version; return FILE_VERSION == version;
} }
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
TEST_CASE("dns_mapping init null hostent") TEST_CASE("dns_mapping init null hostent") {
{
DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A); DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A);
CHECK(! mapping.Valid()); CHECK(! mapping.Valid());
CHECK(mapping.Addrs() == nullptr); CHECK(mapping.Addrs() == nullptr);
CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set())); CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set()));
CHECK(mapping.Host() == nullptr); CHECK(mapping.Host() == nullptr);
} }
TEST_CASE("dns_mapping init host") TEST_CASE("dns_mapping init host") {
{
IPAddr addr("1.2.3.4"); IPAddr addr("1.2.3.4");
in4_addr in4; in4_addr in4;
addr.CopyIPv4(&in4); addr.CopyIPv4(&in4);
@ -253,10 +228,9 @@ TEST_CASE("dns_mapping init host")
CHECK(svh->ToStdString() == "testing.home"); CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping init addr") TEST_CASE("dns_mapping init addr") {
{
IPAddr addr("1.2.3.4"); IPAddr addr("1.2.3.4");
in4_addr in4; in4_addr in4;
addr.CopyIPv4(&in4); addr.CopyIPv4(&in4);
@ -292,10 +266,9 @@ TEST_CASE("dns_mapping init addr")
CHECK(svh->ToStdString() == "testing.home"); CHECK(svh->ToStdString() == "testing.home");
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping save reload") TEST_CASE("dns_mapping save reload") {
{
// TODO: this test uses fmemopen and mkdtemp, both of which aren't available on // TODO: this test uses fmemopen and mkdtemp, both of which aren't available on
// Windows. We'll have to figure out another way to do this test there. // Windows. We'll have to figure out another way to do this test there.
#ifndef _MSC_VER #ifndef _MSC_VER
@ -360,10 +333,9 @@ TEST_CASE("dns_mapping save reload")
delete[] he.h_name; delete[] he.h_name;
#endif #endif
} }
TEST_CASE("dns_mapping multiple addresses") TEST_CASE("dns_mapping multiple addresses") {
{
IPAddr addr("1.2.3.4"); IPAddr addr("1.2.3.4");
in4_addr in4_1; in4_addr in4_1;
addr.CopyIPv4(&in4_1); addr.CopyIPv4(&in4_1);
@ -397,10 +369,9 @@ TEST_CASE("dns_mapping multiple addresses")
CHECK(lvae->Get().AsString() == "5.6.7.8"); CHECK(lvae->Get().AsString() == "5.6.7.8");
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping ipv6") TEST_CASE("dns_mapping ipv6") {
{
IPAddr addr("64:ff9b:1::"); IPAddr addr("64:ff9b:1::");
in6_addr in6; in6_addr in6;
addr.CopyIPv6(&in6); addr.CopyIPv6(&in6);
@ -428,6 +399,6 @@ TEST_CASE("dns_mapping ipv6")
CHECK(lvae->Get().AsString() == "64:ff9b:1::"); CHECK(lvae->Get().AsString() == "64:ff9b:1::");
delete[] he.h_name; delete[] he.h_name;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -8,14 +8,12 @@
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
class DNS_Mapping; class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>; using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Mapping class DNS_Mapping {
{
public: public:
DNS_Mapping() = delete; DNS_Mapping() = delete;
DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type); DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type);
@ -75,6 +73,6 @@ protected:
bool no_mapping = false; // when initializing from a file, immediately hit EOF bool no_mapping = false; // when initializing from a file, immediately hit EOF
bool init_failed = false; bool init_failed = false;
bool failed = false; bool failed = false;
}; };
} // namespace zeek::detail } // namespace zeek::detail

File diff suppressed because it is too large Load diff

View file

@ -27,43 +27,39 @@ typedef struct ares_channeldata* ares_channel;
#define T_TXT 16 #define T_TXT 16
#endif #endif
namespace zeek namespace zeek {
{
class Val; class Val;
class ListVal; class ListVal;
class TableVal; class TableVal;
class StringVal; class StringVal;
template <class T> class IntrusivePtr; template<class T>
class IntrusivePtr;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using ListValPtr = IntrusivePtr<ListVal>; using ListValPtr = IntrusivePtr<ListVal>;
using TableValPtr = IntrusivePtr<TableVal>; using TableValPtr = IntrusivePtr<TableVal>;
using StringValPtr = IntrusivePtr<StringVal>; using StringValPtr = IntrusivePtr<StringVal>;
} // namespace zeek } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class DNS_Mapping; class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>; using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Request; class DNS_Request;
enum DNS_MgrMode enum DNS_MgrMode {
{
DNS_PRIME, // used to prime the cache DNS_PRIME, // used to prime the cache
DNS_FORCE, // internal error if cache miss DNS_FORCE, // internal error if cache miss
DNS_DEFAULT, // lookup names as they're requested DNS_DEFAULT, // lookup names as they're requested
DNS_FAKE, // don't look up names, just return dummy results DNS_FAKE, // don't look up names, just return dummy results
}; };
class DNS_Mgr : public iosource::IOSource class DNS_Mgr : public iosource::IOSource {
{
public: public:
/** /**
* Base class for callback handling for asynchronous lookups. * Base class for callback handling for asynchronous lookups.
*/ */
class LookupCallback class LookupCallback {
{
public: public:
virtual ~LookupCallback() = default; virtual ~LookupCallback() = default;
@ -86,7 +82,7 @@ public:
* *
* @param val A Val containing the data from the query. * @param val A Val containing the data from the query.
*/ */
virtual void Resolved(ValPtr data, int request_type) { } virtual void Resolved(ValPtr data, int request_type) {}
/** /**
* Called when a timeout request occurs. * Called when a timeout request occurs.
@ -200,8 +196,7 @@ public:
*/ */
bool Save(); bool Save();
struct Stats struct Stats {
{
unsigned long requests; // These count only async requests. unsigned long requests; // These count only async requests.
unsigned long successful; unsigned long successful;
unsigned long failed; unsigned long failed;
@ -254,12 +249,9 @@ protected:
friend class LookupCallback; friend class LookupCallback;
friend class DNS_Request; friend class DNS_Request;
StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, bool check_failed = false);
bool check_failed = false); TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, bool check_failed = false);
TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, StringValPtr LookupOtherInCache(const std::string& name, int request_type, bool cleanup_expired = false);
bool check_failed = false);
StringValPtr LookupOtherInCache(const std::string& name, int request_type,
bool cleanup_expired = false);
// Finish the request if we have a result. If not, time it out if // Finish the request if we have a result. If not, time it out if
// requested. // requested.
@ -307,8 +299,7 @@ protected:
using CallbackList = std::list<LookupCallback*>; using CallbackList = std::list<LookupCallback*>;
struct AsyncRequest struct AsyncRequest {
{
double time = 0.0; double time = 0.0;
IPAddr addr; IPAddr addr;
std::string host; std::string host;
@ -316,18 +307,15 @@ protected:
int type = 0; int type = 0;
bool processed = false; bool processed = false;
AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) {}
{ AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) {}
}
AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) { }
void Resolved(const std::string& name); void Resolved(const std::string& name);
void Resolved(TableValPtr addrs); void Resolved(TableValPtr addrs);
void Timeout(); void Timeout();
}; };
struct AsyncRequestCompare struct AsyncRequestCompare {
{
bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; } bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; }
}; };
@ -345,8 +333,8 @@ protected:
std::set<int> write_socket_fds; std::set<int> write_socket_fds;
bool shutting_down = false; bool shutting_down = false;
}; };
extern DNS_Mgr* dns_mgr; extern DNS_Mgr* dns_mgr;
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -18,34 +18,27 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/module_util.h" #include "zeek/module_util.h"
namespace zeek::detail namespace zeek::detail {
{
// BreakpointTimer used for time-based breakpoints // BreakpointTimer used for time-based breakpoints
class BreakpointTimer final : public Timer class BreakpointTimer final : public Timer {
{
public: public:
BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) { bp = arg_bp; }
{
bp = arg_bp;
}
void Dispatch(double t, bool is_expire) override; void Dispatch(double t, bool is_expire) override;
protected: protected:
DbgBreakpoint* bp; DbgBreakpoint* bp;
}; };
void BreakpointTimer::Dispatch(double t, bool is_expire) void BreakpointTimer::Dispatch(double t, bool is_expire) {
{
if ( is_expire ) if ( is_expire )
return; return;
bp->ShouldBreak(t); bp->ShouldBreak(t);
} }
DbgBreakpoint::DbgBreakpoint() DbgBreakpoint::DbgBreakpoint() {
{
kind = BP_STMT; kind = BP_STMT;
enabled = temporary = false; enabled = temporary = false;
@ -59,16 +52,14 @@ DbgBreakpoint::DbgBreakpoint()
description[0] = 0; description[0] = 0;
source_filename = nullptr; source_filename = nullptr;
source_line = 0; source_line = 0;
} }
DbgBreakpoint::~DbgBreakpoint() DbgBreakpoint::~DbgBreakpoint() {
{
SetEnable(false); // clean up any active state SetEnable(false); // clean up any active state
RemoveFromGlobalMap(); RemoveFromGlobalMap();
} }
bool DbgBreakpoint::SetEnable(bool do_enable) bool DbgBreakpoint::SetEnable(bool do_enable) {
{
bool old_value = enabled; bool old_value = enabled;
enabled = do_enable; enabled = do_enable;
@ -80,25 +71,21 @@ bool DbgBreakpoint::SetEnable(bool do_enable)
RemoveFromStmt(); RemoveFromStmt();
return old_value; return old_value;
} }
void DbgBreakpoint::AddToGlobalMap() void DbgBreakpoint::AddToGlobalMap() {
{
// Make sure it's not there already. // Make sure it's not there already.
RemoveFromGlobalMap(); RemoveFromGlobalMap();
g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this)); g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this));
} }
void DbgBreakpoint::RemoveFromGlobalMap() void DbgBreakpoint::RemoveFromGlobalMap() {
{
std::pair<BPMapType::iterator, BPMapType::iterator> p; std::pair<BPMapType::iterator, BPMapType::iterator> p;
p = g_debugger_state.breakpoint_map.equal_range(at_stmt); p = g_debugger_state.breakpoint_map.equal_range(at_stmt);
for ( BPMapType::iterator i = p.first; i != p.second; ) for ( BPMapType::iterator i = p.first; i != p.second; ) {
{ if ( i->second == this ) {
if ( i->second == this )
{
BPMapType::iterator next = i; BPMapType::iterator next = i;
++next; ++next;
g_debugger_state.breakpoint_map.erase(i); g_debugger_state.breakpoint_map.erase(i);
@ -107,36 +94,30 @@ void DbgBreakpoint::RemoveFromGlobalMap()
else else
++i; ++i;
} }
} }
void DbgBreakpoint::AddToStmt() void DbgBreakpoint::AddToStmt() {
{
if ( at_stmt ) if ( at_stmt )
at_stmt->IncrBPCount(); at_stmt->IncrBPCount();
} }
void DbgBreakpoint::RemoveFromStmt() void DbgBreakpoint::RemoveFromStmt() {
{
if ( at_stmt ) if ( at_stmt )
at_stmt->DecrBPCount(); at_stmt->DecrBPCount();
} }
bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str) bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str) {
{ if ( plr.type == PLR_UNKNOWN ) {
if ( plr.type == PLR_UNKNOWN )
{
debug_msg("Breakpoint specifier invalid or operation canceled.\n"); debug_msg("Breakpoint specifier invalid or operation canceled.\n");
return false; return false;
} }
if ( plr.type == PLR_FILE_AND_LINE ) if ( plr.type == PLR_FILE_AND_LINE ) {
{
kind = BP_LINE; kind = BP_LINE;
source_filename = plr.filename; source_filename = plr.filename;
source_line = plr.line; source_line = plr.line;
if ( ! plr.stmt ) if ( ! plr.stmt ) {
{
debug_msg("No statement at that line.\n"); debug_msg("No statement at that line.\n");
return false; return false;
} }
@ -147,15 +128,13 @@ bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str)
debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
} }
else if ( plr.type == PLR_FUNCTION ) else if ( plr.type == PLR_FUNCTION ) {
{
std::string loc_s(loc_str); std::string loc_s(loc_str);
kind = BP_FUNC; kind = BP_FUNC;
function_name = make_full_var_name(current_module.c_str(), loc_s.c_str()); function_name = make_full_var_name(current_module.c_str(), loc_s.c_str());
at_stmt = plr.stmt; at_stmt = plr.stmt;
const Location* loc = at_stmt->GetLocationInfo(); const Location* loc = at_stmt->GetLocationInfo();
snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), loc->filename, loc->last_line);
loc->filename, loc->last_line);
debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
} }
@ -163,10 +142,9 @@ bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str)
SetEnable(true); SetEnable(true);
AddToGlobalMap(); AddToGlobalMap();
return true; return true;
} }
bool DbgBreakpoint::SetLocation(Stmt* stmt) bool DbgBreakpoint::SetLocation(Stmt* stmt) {
{
if ( ! stmt ) if ( ! stmt )
return false; return false;
@ -182,10 +160,9 @@ bool DbgBreakpoint::SetLocation(Stmt* stmt)
debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
return true; return true;
} }
bool DbgBreakpoint::SetLocation(double t) bool DbgBreakpoint::SetLocation(double t) {
{
debug_msg("SetLocation(time) has not been debugged."); debug_msg("SetLocation(time) has not been debugged.");
return false; return false;
@ -196,24 +173,20 @@ bool DbgBreakpoint::SetLocation(double t)
debug_msg("Time-based breakpoints not yet supported.\n"); debug_msg("Time-based breakpoints not yet supported.\n");
return false; return false;
} }
bool DbgBreakpoint::Reset() bool DbgBreakpoint::Reset() {
{
ParseLocationRec plr; ParseLocationRec plr;
switch ( kind ) switch ( kind ) {
{ case BP_TIME: debug_msg("Time-based breakpoints not yet supported.\n"); break;
case BP_TIME:
debug_msg("Time-based breakpoints not yet supported.\n");
break;
case BP_FUNC: case BP_FUNC:
case BP_STMT: case BP_STMT:
case BP_LINE: case BP_LINE:
plr.type = PLR_FUNCTION; plr.type = PLR_FUNCTION;
//### How to deal with wildcards? // ### How to deal with wildcards?
//### perhaps save user choices?--tough... // ### perhaps save user choices?--tough...
break; break;
} }
@ -221,61 +194,50 @@ bool DbgBreakpoint::Reset()
// Cannot be reached. // Cannot be reached.
return false; return false;
} }
bool DbgBreakpoint::SetCondition(const std::string& new_condition) bool DbgBreakpoint::SetCondition(const std::string& new_condition) {
{
condition = new_condition; condition = new_condition;
return true; return true;
} }
bool DbgBreakpoint::SetRepeatCount(int count) bool DbgBreakpoint::SetRepeatCount(int count) {
{
repeat_count = count; repeat_count = count;
return true; return true;
} }
BreakCode DbgBreakpoint::HasHit() BreakCode DbgBreakpoint::HasHit() {
{ if ( temporary ) {
if ( temporary )
{
SetEnable(false); SetEnable(false);
return BC_HIT_AND_DELETE; return BC_HIT_AND_DELETE;
} }
if ( condition.size() ) if ( condition.size() ) {
{
// TODO: ### evaluate using debugger frame too // TODO: ### evaluate using debugger frame too
auto yes = dbg_eval_expr(condition.c_str()); auto yes = dbg_eval_expr(condition.c_str());
if ( ! yes ) if ( ! yes ) {
{ debug_msg("Breakpoint condition '%s' invalid, removing condition.\n", condition.c_str());
debug_msg("Breakpoint condition '%s' invalid, removing condition.\n",
condition.c_str());
SetCondition(""); SetCondition("");
PrintHitMsg(); PrintHitMsg();
return BC_HIT; return BC_HIT;
} }
if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) {
{
PrintHitMsg(); PrintHitMsg();
debug_msg("Breakpoint condition should return an integral type"); debug_msg("Breakpoint condition should return an integral type");
return BC_HIT_AND_DELETE; return BC_HIT_AND_DELETE;
} }
yes->CoerceToInt(); yes->CoerceToInt();
if ( yes->IsZero() ) if ( yes->IsZero() ) {
{
return BC_NO_HIT; return BC_NO_HIT;
} }
} }
int repcount = GetRepeatCount(); int repcount = GetRepeatCount();
if ( repcount ) if ( repcount ) {
{ if ( ++hit_count == repcount ) {
if ( ++hit_count == repcount )
{
hit_count = 0; hit_count = 0;
PrintHitMsg(); PrintHitMsg();
return BC_HIT; return BC_HIT;
@ -286,15 +248,13 @@ BreakCode DbgBreakpoint::HasHit()
PrintHitMsg(); PrintHitMsg();
return BC_HIT; return BC_HIT;
} }
BreakCode DbgBreakpoint::ShouldBreak(Stmt* s) BreakCode DbgBreakpoint::ShouldBreak(Stmt* s) {
{
if ( ! IsEnabled() ) if ( ! IsEnabled() )
return BC_NO_HIT; return BC_NO_HIT;
switch ( kind ) switch ( kind ) {
{
case BP_STMT: case BP_STMT:
case BP_FUNC: case BP_FUNC:
if ( at_stmt != s ) if ( at_stmt != s )
@ -302,15 +262,12 @@ BreakCode DbgBreakpoint::ShouldBreak(Stmt* s)
break; break;
case BP_LINE: case BP_LINE:
assert(s->GetLocationInfo()->first_line <= source_line && assert(s->GetLocationInfo()->first_line <= source_line && s->GetLocationInfo()->last_line >= source_line);
s->GetLocationInfo()->last_line >= source_line);
break; break;
case BP_TIME: case BP_TIME: assert(false);
assert(false);
default: default: reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak");
reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak");
} }
// If we got here, that means that the breakpoint could hit, // If we got here, that means that the breakpoint could hit,
@ -321,10 +278,9 @@ BreakCode DbgBreakpoint::ShouldBreak(Stmt* s)
g_debugger_state.BreakBeforeNextStmt(true); g_debugger_state.BreakBeforeNextStmt(true);
return code; return code;
} }
BreakCode DbgBreakpoint::ShouldBreak(double t) BreakCode DbgBreakpoint::ShouldBreak(double t) {
{
if ( kind != BP_TIME ) if ( kind != BP_TIME )
reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint"); reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint");
@ -339,16 +295,13 @@ BreakCode DbgBreakpoint::ShouldBreak(double t)
g_debugger_state.BreakBeforeNextStmt(true); g_debugger_state.BreakBeforeNextStmt(true);
return code; return code;
} }
void DbgBreakpoint::PrintHitMsg() void DbgBreakpoint::PrintHitMsg() {
{ switch ( kind ) {
switch ( kind )
{
case BP_STMT: case BP_STMT:
case BP_FUNC: case BP_FUNC:
case BP_LINE: case BP_LINE: {
{
ODesc d; ODesc d;
Frame* f = g_frame_stack.back(); Frame* f = g_frame_stack.back();
const ScriptFunc* func = f->GetFunction(); const ScriptFunc* func = f->GetFunction();
@ -358,17 +311,14 @@ void DbgBreakpoint::PrintHitMsg()
const Location* loc = at_stmt->GetLocationInfo(); const Location* loc = at_stmt->GetLocationInfo();
debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename, debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename, loc->first_line);
loc->first_line);
} }
return; return;
case BP_TIME: case BP_TIME: assert(false);
assert(false);
default: default: reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n");
reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n");
}
} }
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,27 +6,14 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
class Stmt; class Stmt;
class ParseLocationRec; class ParseLocationRec;
enum BreakCode enum BreakCode { BC_NO_HIT, BC_HIT, BC_HIT_AND_DELETE };
{ class DbgBreakpoint {
BC_NO_HIT, enum Kind { BP_STMT = 0, BP_FUNC, BP_LINE, BP_TIME };
BC_HIT,
BC_HIT_AND_DELETE
};
class DbgBreakpoint
{
enum Kind
{
BP_STMT = 0,
BP_FUNC,
BP_LINE,
BP_TIME
};
public: public:
DbgBreakpoint(); DbgBreakpoint();
@ -95,6 +82,6 @@ protected:
int32_t hit_count; // how many times it's been hit (w/o breaking) int32_t hit_count; // how many times it's been hit (w/o breaking)
std::string condition; // condition to evaluate; nil for none std::string condition; // condition to evaluate; nil for none
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -2,20 +2,17 @@
#pragma once #pragma once
namespace zeek::detail namespace zeek::detail {
{
class Expr; class Expr;
// Automatic displays: display these at each stoppage. // Automatic displays: display these at each stoppage.
class DbgDisplay class DbgDisplay {
{
public: public:
DbgDisplay(Expr* expr_to_display); DbgDisplay(Expr* expr_to_display);
bool IsEnabled() { return enabled; } bool IsEnabled() { return enabled; }
bool SetEnable(bool do_enable) bool SetEnable(bool do_enable) {
{
bool old_value = enabled; bool old_value = enabled;
enabled = do_enable; enabled = do_enable;
return old_value; return old_value;
@ -26,6 +23,6 @@ public:
protected: protected:
bool enabled; bool enabled;
Expr* expression; Expr* expression;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,18 +7,11 @@
#include "zeek/Debug.h" #include "zeek/Debug.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
// Support classes // Support classes
DbgWatch::DbgWatch(zeek::Obj* var_to_watch) DbgWatch::DbgWatch(zeek::Obj* var_to_watch) { reporter->InternalError("DbgWatch unimplemented"); }
{
reporter->InternalError("DbgWatch unimplemented");
}
DbgWatch::DbgWatch(Expr* expr_to_watch) DbgWatch::DbgWatch(Expr* expr_to_watch) { reporter->InternalError("DbgWatch unimplemented"); }
{
reporter->InternalError("DbgWatch unimplemented");
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,18 +4,15 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
class Obj; class Obj;
} }
namespace zeek::detail namespace zeek::detail {
{
class Expr; class Expr;
class DbgWatch class DbgWatch {
{
public: public:
explicit DbgWatch(Obj* var_to_watch); explicit DbgWatch(Obj* var_to_watch);
explicit DbgWatch(Expr* expr_to_watch); explicit DbgWatch(Expr* expr_to_watch);
@ -24,6 +21,6 @@ public:
protected: protected:
Obj* var; Obj* var;
Expr* expr; Expr* expr;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -32,10 +32,9 @@
#include "zeek/module_util.h" #include "zeek/module_util.h"
#include "zeek/util.h" #include "zeek/util.h"
extern "C" extern "C" {
{
#include "zeek/3rdparty/setsignal.h" #include "zeek/3rdparty/setsignal.h"
} }
using namespace std; using namespace std;
@ -66,11 +65,9 @@ extern YYLTYPE yylloc; // holds start line and column of token
extern int line_number; extern int line_number;
extern const char* filename; extern const char* filename;
namespace zeek::detail namespace zeek::detail {
{
DebuggerState::DebuggerState() DebuggerState::DebuggerState() {
{
next_bp_id = next_watch_id = next_display_id = 1; next_bp_id = next_watch_id = next_display_id = 1;
BreakBeforeNextStmt(false); BreakBeforeNextStmt(false);
curr_frame_idx = 0; curr_frame_idx = 0;
@ -79,25 +76,20 @@ DebuggerState::DebuggerState()
// ### Don't choose this arbitrary size! Extend Frame. // ### Don't choose this arbitrary size! Extend Frame.
dbg_locals = new Frame(1024, /* func = */ nullptr, /* fn_args = */ nullptr); dbg_locals = new Frame(1024, /* func = */ nullptr, /* fn_args = */ nullptr);
} }
DebuggerState::~DebuggerState() DebuggerState::~DebuggerState() { Unref(dbg_locals); }
{
Unref(dbg_locals);
}
bool StmtLocMapping::StartsAfter(const StmtLocMapping* m2) bool StmtLocMapping::StartsAfter(const StmtLocMapping* m2) {
{
if ( ! m2 ) if ( ! m2 )
reporter->InternalError("Assertion failed: m2 != 0"); reporter->InternalError("Assertion failed: m2 != 0");
return loc.first_line > m2->loc.first_line || return loc.first_line > m2->loc.first_line ||
(loc.first_line == m2->loc.first_line && loc.first_column > m2->loc.first_column); (loc.first_line == m2->loc.first_line && loc.first_column > m2->loc.first_column);
} }
// Generic debug message output. // Generic debug message output.
int debug_msg(const char* fmt, ...) int debug_msg(const char* fmt, ...) {
{
va_list args; va_list args;
int retval; int retval;
@ -106,12 +98,11 @@ int debug_msg(const char* fmt, ...)
va_end(args); va_end(args);
return retval; return retval;
} }
// Trace message output // Trace message output
FILE* TraceState::SetTraceFile(const char* trace_filename) FILE* TraceState::SetTraceFile(const char* trace_filename) {
{
FILE* newfile; FILE* newfile;
if ( util::streq(trace_filename, "-") ) if ( util::streq(trace_filename, "-") )
@ -120,33 +111,28 @@ FILE* TraceState::SetTraceFile(const char* trace_filename)
newfile = fopen(trace_filename, "w"); newfile = fopen(trace_filename, "w");
FILE* oldfile = trace_file; FILE* oldfile = trace_file;
if ( newfile ) if ( newfile ) {
{
trace_file = newfile; trace_file = newfile;
} }
else else {
{
fprintf(stderr, "Unable to open trace file %s\n", trace_filename); fprintf(stderr, "Unable to open trace file %s\n", trace_filename);
trace_file = nullptr; trace_file = nullptr;
} }
return oldfile; return oldfile;
} }
void TraceState::TraceOn() void TraceState::TraceOn() {
{
fprintf(stderr, "Execution tracing ON.\n"); fprintf(stderr, "Execution tracing ON.\n");
dbgtrace = true; dbgtrace = true;
} }
void TraceState::TraceOff() void TraceState::TraceOff() {
{
fprintf(stderr, "Execution tracing OFF.\n"); fprintf(stderr, "Execution tracing OFF.\n");
dbgtrace = false; dbgtrace = false;
} }
int TraceState::LogTrace(const char* fmt, ...) int TraceState::LogTrace(const char* fmt, ...) {
{
va_list args; va_list args;
int retval; int retval;
@ -159,21 +145,18 @@ int TraceState::LogTrace(const char* fmt, ...)
Location loc; Location loc;
loc.filename = nullptr; loc.filename = nullptr;
if ( g_frame_stack.size() > 0 && g_frame_stack.back() ) if ( g_frame_stack.size() > 0 && g_frame_stack.back() ) {
{
stmt = g_frame_stack.back()->GetNextStmt(); stmt = g_frame_stack.back()->GetNextStmt();
if ( stmt ) if ( stmt )
loc = *stmt->GetLocationInfo(); loc = *stmt->GetLocationInfo();
else else {
{
const ScriptFunc* f = g_frame_stack.back()->GetFunction(); const ScriptFunc* f = g_frame_stack.back()->GetFunction();
if ( f ) if ( f )
loc = *f->GetLocationInfo(); loc = *f->GetLocationInfo();
} }
} }
if ( ! loc.filename ) if ( ! loc.filename ) {
{
loc.filename = util::copy_string("<no filename>"); loc.filename = util::copy_string("<no filename>");
loc.last_line = 0; loc.last_line = 0;
} }
@ -190,20 +173,17 @@ int TraceState::LogTrace(const char* fmt, ...)
va_end(args); va_end(args);
return retval; return retval;
} }
// Helper functions. // Helper functions.
void get_first_statement(Stmt* list, Stmt*& first, Location& loc) void get_first_statement(Stmt* list, Stmt*& first, Location& loc) {
{ if ( ! list ) {
if ( ! list )
{
first = nullptr; first = nullptr;
return; return;
} }
first = list; first = list;
while ( first->Tag() == STMT_LIST ) while ( first->Tag() == STMT_LIST ) {
{
if ( first->AsStmtList()->Stmts()[0] ) if ( first->AsStmtList()->Stmts()[0] )
first = first->AsStmtList()->Stmts()[0].get(); first = first->AsStmtList()->Stmts()[0].get();
else else
@ -211,30 +191,26 @@ void get_first_statement(Stmt* list, Stmt*& first, Location& loc)
} }
loc = *first->GetLocationInfo(); loc = *first->GetLocationInfo();
} }
static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationRec& plr, static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationRec& plr,
const string& s) const string& s) { // function name
{ // function name
const auto& id = lookup_ID(s.c_str(), current_module.c_str()); const auto& id = lookup_ID(s.c_str(), current_module.c_str());
if ( ! id ) if ( ! id ) {
{
string fullname = make_full_var_name(current_module.c_str(), s.c_str()); string fullname = make_full_var_name(current_module.c_str(), s.c_str());
debug_msg("Function %s not defined.\n", fullname.c_str()); debug_msg("Function %s not defined.\n", fullname.c_str());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
if ( ! id->GetType()->AsFuncType() ) if ( ! id->GetType()->AsFuncType() ) {
{
debug_msg("Function %s not declared.\n", id->Name()); debug_msg("Function %s not declared.\n", id->Name());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
if ( ! id->HasVal() ) if ( ! id->HasVal() ) {
{
debug_msg("Function %s declared but not defined.\n", id->Name()); debug_msg("Function %s declared but not defined.\n", id->Name());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
@ -243,8 +219,7 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
const Func* func = id->GetVal()->AsFunc(); const Func* func = id->GetVal()->AsFunc();
const vector<Func::Body>& bodies = func->GetBodies(); const vector<Func::Body>& bodies = func->GetBodies();
if ( bodies.size() == 0 ) if ( bodies.size() == 0 ) {
{
debug_msg("Function %s is a built-in function\n", id->Name()); debug_msg("Function %s is a built-in function\n", id->Name());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
@ -254,14 +229,12 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
if ( bodies.size() == 1 ) if ( bodies.size() == 1 )
body = bodies[0].stmts.get(); body = bodies[0].stmts.get();
else else {
{ while ( true ) {
while ( true ) debug_msg(
{ "There are multiple definitions of that event handler.\n"
debug_msg("There are multiple definitions of that event handler.\n"
"Please choose one of the following options:\n"); "Please choose one of the following options:\n");
for ( unsigned int i = 0; i < bodies.size(); ++i ) for ( unsigned int i = 0; i < bodies.size(); ++i ) {
{
Stmt* first; Stmt* first;
Location stmt_loc; Location stmt_loc;
get_first_statement(bodies[i].stmts.get(), first, stmt_loc); get_first_statement(bodies[i].stmts.get(), first, stmt_loc);
@ -273,8 +246,7 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
debug_msg("Enter your choice: "); debug_msg("Enter your choice: ");
char charinput[256]; char charinput[256];
if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) {
{
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
@ -287,15 +259,13 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
if ( input == "a" ) if ( input == "a" )
break; break;
if ( input == "n" ) if ( input == "n" ) {
{
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
int option = atoi(input.c_str()); int option = atoi(input.c_str());
if ( option > 0 && option <= (int)bodies.size() ) if ( option > 0 && option <= (int)bodies.size() ) {
{
body = bodies[option - 1].stmts.get(); body = bodies[option - 1].stmts.get();
break; break;
} }
@ -308,24 +278,20 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
Stmt* first; Stmt* first;
Location stmt_loc; Location stmt_loc;
if ( body ) if ( body ) {
{
get_first_statement(body, first, stmt_loc); get_first_statement(body, first, stmt_loc);
if ( first ) if ( first ) {
{
plr.stmt = first; plr.stmt = first;
plr.filename = stmt_loc.filename; plr.filename = stmt_loc.filename;
plr.line = stmt_loc.last_line; plr.line = stmt_loc.last_line;
} }
} }
else else {
{
result.pop_back(); result.pop_back();
ParseLocationRec result_plr; ParseLocationRec result_plr;
for ( const auto& body : bodies ) for ( const auto& body : bodies ) {
{
get_first_statement(body.stmts.get(), first, stmt_loc); get_first_statement(body.stmts.get(), first, stmt_loc);
if ( ! first ) if ( ! first )
continue; continue;
@ -337,10 +303,9 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
result.push_back(result_plr); result.push_back(result_plr);
} }
} }
} }
vector<ParseLocationRec> parse_location_string(const string& s) vector<ParseLocationRec> parse_location_string(const string& s) {
{
vector<ParseLocationRec> result; vector<ParseLocationRec> result;
result.emplace_back(); result.emplace_back();
ParseLocationRec& plr = result[0]; ParseLocationRec& plr = result[0];
@ -350,21 +315,18 @@ vector<ParseLocationRec> parse_location_string(const string& s)
// up the line number to find the corresponding statement. // up the line number to find the corresponding statement.
std::string loc_filename; std::string loc_filename;
if ( sscanf(s.c_str(), "%d", &plr.line) ) if ( sscanf(s.c_str(), "%d", &plr.line) ) { // just a line number (implicitly referring to the current file)
{ // just a line number (implicitly referring to the current file)
loc_filename = g_debugger_state.last_loc.filename; loc_filename = g_debugger_state.last_loc.filename;
plr.type = PLR_FILE_AND_LINE; plr.type = PLR_FILE_AND_LINE;
} }
else else {
{
string::size_type pos_colon = s.find(':'); string::size_type pos_colon = s.find(':');
string::size_type pos_dblcolon = s.find("::"); string::size_type pos_dblcolon = s.find("::");
if ( pos_colon == string::npos || pos_dblcolon != string::npos ) if ( pos_colon == string::npos || pos_dblcolon != string::npos )
parse_function_name(result, plr, s); parse_function_name(result, plr, s);
else else { // file:line
{ // file:line
string policy_filename = s.substr(0, pos_colon); string policy_filename = s.substr(0, pos_colon);
string line_string = s.substr(pos_colon + 1, s.length() - pos_colon); string line_string = s.substr(pos_colon + 1, s.length() - pos_colon);
@ -373,8 +335,7 @@ vector<ParseLocationRec> parse_location_string(const string& s)
string path(util::find_script_file(policy_filename, util::zeek_path())); string path(util::find_script_file(policy_filename, util::zeek_path()));
if ( path.empty() ) if ( path.empty() ) {
{
debug_msg("No such policy file: %s.\n", policy_filename.c_str()); debug_msg("No such policy file: %s.\n", policy_filename.c_str());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return result; return result;
@ -385,30 +346,25 @@ vector<ParseLocationRec> parse_location_string(const string& s)
} }
} }
if ( plr.type == PLR_FILE_AND_LINE ) if ( plr.type == PLR_FILE_AND_LINE ) {
{
auto iter = g_dbgfilemaps.find(loc_filename); auto iter = g_dbgfilemaps.find(loc_filename);
if ( iter == g_dbgfilemaps.end() ) if ( iter == g_dbgfilemaps.end() )
reporter->InternalError("Policy file %s should have been loaded\n", reporter->InternalError("Policy file %s should have been loaded\n", loc_filename.data());
loc_filename.data());
if ( plr.line > how_many_lines_in(loc_filename.data()) ) if ( plr.line > how_many_lines_in(loc_filename.data()) ) {
{
debug_msg("No line %d in %s.\n", plr.line, loc_filename.data()); debug_msg("No line %d in %s.\n", plr.line, loc_filename.data());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return result; return result;
} }
StmtLocMapping* hit = nullptr; StmtLocMapping* hit = nullptr;
for ( const auto entry : *(iter->second) ) for ( const auto entry : *(iter->second) ) {
{
plr.filename = entry->Loc().filename; plr.filename = entry->Loc().filename;
if ( entry->Loc().first_line > plr.line ) if ( entry->Loc().first_line > plr.line )
break; break;
if ( plr.line >= entry->Loc().first_line && plr.line <= entry->Loc().last_line ) if ( plr.line >= entry->Loc().first_line && plr.line <= entry->Loc().last_line ) {
{
hit = entry; hit = entry;
break; break;
} }
@ -421,7 +377,7 @@ vector<ParseLocationRec> parse_location_string(const string& s)
} }
return result; return result;
} }
// Interactive debugging console. // Interactive debugging console.
@ -431,26 +387,23 @@ static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args);
void using_history(void); void using_history(void);
static bool init_readline() static bool init_readline() {
{
// ### Set up custom completion. // ### Set up custom completion.
rl_outstream = stderr; rl_outstream = stderr;
using_history(); using_history();
return false; return false;
} }
#endif #endif
void break_signal(int) void break_signal(int) {
{
g_debugger_state.BreakBeforeNextStmt(true); g_debugger_state.BreakBeforeNextStmt(true);
g_debugger_state.BreakFromSignal(true); g_debugger_state.BreakFromSignal(true);
} }
int dbg_init_debugger(const char* cmdfile) int dbg_init_debugger(const char* cmdfile) {
{
if ( ! g_policy_debug ) if ( ! g_policy_debug )
return 0; // probably shouldn't have been called return 0; // probably shouldn't have been called
@ -472,13 +425,12 @@ int dbg_init_debugger(const char* cmdfile)
setsignal(SIGTERM, break_signal); setsignal(SIGTERM, break_signal);
return 1; return 1;
} }
int dbg_shutdown_debugger() int dbg_shutdown_debugger() {
{
// ### TODO: Remove signal handlers // ### TODO: Remove signal handlers
return 1; return 1;
} }
// Umesh: I stole this code from libedit; I modified it here to use // Umesh: I stole this code from libedit; I modified it here to use
// <string>s to avoid memory management problems. The main command is returned // <string>s to avoid memory management problems. The main command is returned
@ -488,23 +440,19 @@ int dbg_shutdown_debugger()
// Parse the string into individual tokens, similarly to how shell // Parse the string into individual tokens, similarly to how shell
// would do it. // would do it.
void tokenize(const char* cstr, string& operation, vector<string>& arguments) void tokenize(const char* cstr, string& operation, vector<string>& arguments) {
{
int num_tokens = 0; int num_tokens = 0;
char delim = '\0'; char delim = '\0';
const string str(cstr); const string str(cstr);
for ( int i = 0; i < (signed int)str.length(); ++i ) for ( int i = 0; i < (signed int)str.length(); ++i ) {
{
while ( isspace((unsigned char)str[i]) ) while ( isspace((unsigned char)str[i]) )
++i; ++i;
int start = i; int start = i;
for ( ; str[i]; ++i ) for ( ; str[i]; ++i ) {
{ if ( str[i] == '\\' ) {
if ( str[i] == '\\' )
{
if ( i < (signed int)str.length() ) if ( i < (signed int)str.length() )
++i; ++i;
} }
@ -515,8 +463,7 @@ void tokenize(const char* cstr, string& operation, vector<string>& arguments)
else if ( ! delim && (str[i] == '\'' || str[i] == '"') ) else if ( ! delim && (str[i] == '\'' || str[i] == '"') )
delim = str[i]; delim = str[i];
else if ( delim && str[i] == delim ) else if ( delim && str[i] == delim ) {
{
delim = '\0'; delim = '\0';
++i; ++i;
break; break;
@ -535,11 +482,10 @@ void tokenize(const char* cstr, string& operation, vector<string>& arguments)
++num_tokens; ++num_tokens;
} }
} }
// Given a command string, parse it and send the command to be dispatched. // Given a command string, parse it and send the command to be dispatched.
int dbg_execute_command(const char* cmd) int dbg_execute_command(const char* cmd) {
{
bool matched_history = false; bool matched_history = false;
if ( ! cmd ) if ( ! cmd )
@ -549,16 +495,14 @@ int dbg_execute_command(const char* cmd)
{ {
#ifdef HAVE_READLINE #ifdef HAVE_READLINE
int i; int i;
for ( i = history_length; i >= 1; --i ) for ( i = history_length; i >= 1; --i ) {
{
HIST_ENTRY* entry = history_get(i); HIST_ENTRY* entry = history_get(i);
if ( ! entry ) if ( ! entry )
return 0; return 0;
const DebugCmdInfo* info = (const DebugCmdInfo*)entry->data; const DebugCmdInfo* info = (const DebugCmdInfo*)entry->data;
if ( info && info->Repeatable() ) if ( info && info->Repeatable() ) {
{
cmd = entry->line; cmd = entry->line;
matched_history = true; matched_history = true;
break; break;
@ -583,14 +527,12 @@ int dbg_execute_command(const char* cmd)
auto matching_cmds = matching_cmds_buf.get(); auto matching_cmds = matching_cmds_buf.get();
int num_matches = find_all_matching_cmds(opstring, matching_cmds); int num_matches = find_all_matching_cmds(opstring, matching_cmds);
if ( ! num_matches ) if ( ! num_matches ) {
{
debug_msg("No Matching command for '%s'.\n", opstring.c_str()); debug_msg("No Matching command for '%s'.\n", opstring.c_str());
return 0; return 0;
} }
if ( num_matches > 1 ) if ( num_matches > 1 ) {
{
debug_msg("Ambiguous command; could be\n"); debug_msg("Ambiguous command; could be\n");
for ( int i = 0; i < num_debug_cmds(); ++i ) for ( int i = 0; i < num_debug_cmds(); ++i )
@ -603,16 +545,14 @@ int dbg_execute_command(const char* cmd)
// Matched exactly one command: find out which one. // Matched exactly one command: find out which one.
DebugCmd cmd_code = dcInvalid; DebugCmd cmd_code = dcInvalid;
for ( int i = 0; i < num_debug_cmds(); ++i ) for ( int i = 0; i < num_debug_cmds(); ++i )
if ( matching_cmds[i] ) if ( matching_cmds[i] ) {
{
cmd_code = (DebugCmd)i; cmd_code = (DebugCmd)i;
break; break;
} }
#ifdef HAVE_READLINE #ifdef HAVE_READLINE
// Insert command into history. // Insert command into history.
if ( ! matched_history && cmd && *cmd ) if ( ! matched_history && cmd && *cmd ) {
{
/* The prototype for add_history(), at least under MacOS, /* The prototype for add_history(), at least under MacOS,
* has it taking a char* rather than a const char*. * has it taking a char* rather than a const char*.
* But documentation at * But documentation at
@ -641,20 +581,14 @@ int dbg_execute_command(const char* cmd)
return -2; // ### yuck, why -2? return -2; // ### yuck, why -2?
return info->ResumeExecution(); return info->ResumeExecution();
} }
// Call the appropriate function for the command. // Call the appropriate function for the command.
static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args) static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args) {
{ switch ( cmd_code ) {
switch ( cmd_code ) case dcHelp: dbg_cmd_help(cmd_code, args); break;
{
case dcHelp:
dbg_cmd_help(cmd_code, args);
break;
case dcQuit: case dcQuit: debug_msg("Program Terminating\n"); exit(0);
debug_msg("Program Terminating\n");
exit(0);
case dcNext: case dcNext:
g_frame_stack.back()->BreakBeforeNextStmt(true); g_frame_stack.back()->BreakBeforeNextStmt(true);
@ -678,60 +612,45 @@ static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args)
g_debugger_state.BreakBeforeNextStmt(false); g_debugger_state.BreakBeforeNextStmt(false);
break; break;
case dcBreak: case dcBreak: dbg_cmd_break(cmd_code, args); break;
dbg_cmd_break(cmd_code, args);
break;
case dcBreakCondition: case dcBreakCondition: dbg_cmd_break_condition(cmd_code, args); break;
dbg_cmd_break_condition(cmd_code, args);
break;
case dcDeleteBreak: case dcDeleteBreak:
case dcClearBreak: case dcClearBreak:
case dcDisableBreak: case dcDisableBreak:
case dcEnableBreak: case dcEnableBreak:
case dcIgnoreBreak: case dcIgnoreBreak: dbg_cmd_break_set_state(cmd_code, args); break;
dbg_cmd_break_set_state(cmd_code, args);
break;
case dcPrint: case dcPrint: dbg_cmd_print(cmd_code, args); break;
dbg_cmd_print(cmd_code, args);
break;
case dcBacktrace: case dcBacktrace: return dbg_cmd_backtrace(cmd_code, args);
return dbg_cmd_backtrace(cmd_code, args);
case dcFrame: case dcFrame:
case dcUp: case dcUp:
case dcDown: case dcDown: return dbg_cmd_frame(cmd_code, args);
return dbg_cmd_frame(cmd_code, args);
case dcInfo: case dcInfo: return dbg_cmd_info(cmd_code, args);
return dbg_cmd_info(cmd_code, args);
case dcList: case dcList: return dbg_cmd_list(cmd_code, args);
return dbg_cmd_list(cmd_code, args);
case dcDisplay: case dcDisplay:
case dcUndisplay: case dcUndisplay: debug_msg("Command not yet implemented.\n"); break;
debug_msg("Command not yet implemented.\n");
break;
case dcTrace: case dcTrace: return dbg_cmd_trace(cmd_code, args);
return dbg_cmd_trace(cmd_code, args);
default: default:
debug_msg("INTERNAL ERROR: " debug_msg(
"INTERNAL ERROR: "
"Got an unknown debugger command in DbgDispatchCmd: %d\n", "Got an unknown debugger command in DbgDispatchCmd: %d\n",
cmd_code); cmd_code);
return 0; return 0;
} }
return 0; return 0;
} }
static char* get_prompt(bool reset_counter = false) static char* get_prompt(bool reset_counter = false) {
{
static char prompt[512]; static char prompt[512];
static int counter = 0; static int counter = 0;
@ -741,10 +660,9 @@ static char* get_prompt(bool reset_counter = false)
snprintf(prompt, sizeof(prompt), "(Zeek [%d]) ", counter++); snprintf(prompt, sizeof(prompt), "(Zeek [%d]) ", counter++);
return prompt; return prompt;
} }
string get_context_description(const Stmt* stmt, const Frame* frame) string get_context_description(const Stmt* stmt, const Frame* frame) {
{
ODesc d; ODesc d;
const ScriptFunc* func = frame ? frame->GetFunction() : nullptr; const ScriptFunc* func = frame ? frame->GetFunction() : nullptr;
@ -756,8 +674,7 @@ string get_context_description(const Stmt* stmt, const Frame* frame)
Location loc; Location loc;
if ( stmt ) if ( stmt )
loc = *stmt->GetLocationInfo(); loc = *stmt->GetLocationInfo();
else else {
{
loc.filename = util::copy_string("<no filename>"); loc.filename = util::copy_string("<no filename>");
loc.last_line = 0; loc.last_line = 0;
} }
@ -769,15 +686,13 @@ string get_context_description(const Stmt* stmt, const Frame* frame)
string retval(buf); string retval(buf);
delete[] buf; delete[] buf;
return retval; return retval;
} }
int dbg_handle_debug_input() int dbg_handle_debug_input() {
{
static char* input_line = nullptr; static char* input_line = nullptr;
int status = 0; int status = 0;
if ( g_debugger_state.BreakFromSignal() ) if ( g_debugger_state.BreakFromSignal() ) {
{
debug_msg("Program received signal SIGINT: entering debugger\n"); debug_msg("Program received signal SIGINT: entering debugger\n");
g_debugger_state.BreakFromSignal(false); g_debugger_state.BreakFromSignal(false);
@ -796,8 +711,7 @@ int dbg_handle_debug_input()
const Location loc = *stmt->GetLocationInfo(); const Location loc = *stmt->GetLocationInfo();
if ( ! step_or_next_pending || g_frame_stack.back() != last_frame ) if ( ! step_or_next_pending || g_frame_stack.back() != last_frame ) {
{
string context = get_context_description(stmt, g_frame_stack.back()); string context = get_context_description(stmt, g_frame_stack.back());
debug_msg("%s\n", context.c_str()); debug_msg("%s\n", context.c_str());
} }
@ -807,8 +721,7 @@ int dbg_handle_debug_input()
PrintLines(loc.filename, loc.first_line, loc.last_line - loc.first_line + 1, true); PrintLines(loc.filename, loc.first_line, loc.last_line - loc.first_line + 1, true);
g_debugger_state.last_loc = loc; g_debugger_state.last_loc = loc;
do do {
{
// readline returns a pointer to a buffer it allocates; it's // readline returns a pointer to a buffer it allocates; it's
// freed at the bottom. // freed at the bottom.
#ifdef HAVE_READLINE #ifdef HAVE_READLINE
@ -830,8 +743,7 @@ int dbg_handle_debug_input()
status = dbg_execute_command(input_line); status = dbg_execute_command(input_line);
if ( input_line ) if ( input_line ) {
{
free(input_line); // this was malloc'ed free(input_line); // this was malloc'ed
input_line = nullptr; input_line = nullptr;
} }
@ -847,16 +759,14 @@ int dbg_handle_debug_input()
setsignal(SIGTERM, break_signal); setsignal(SIGTERM, break_signal);
return 0; return 0;
} }
// Return true to continue execution, false to abort. // Return true to continue execution, false to abort.
bool pre_execute_stmt(Stmt* stmt, Frame* f) bool pre_execute_stmt(Stmt* stmt, Frame* f) {
{
if ( ! g_policy_debug || stmt->Tag() == STMT_LIST || stmt->Tag() == STMT_NULL ) if ( ! g_policy_debug || stmt->Tag() == STMT_LIST || stmt->Tag() == STMT_NULL )
return true; return true;
if ( g_trace_state.DoTrace() ) if ( g_trace_state.DoTrace() ) {
{
ODesc d; ODesc d;
stmt->Describe(&d); stmt->Describe(&d);
@ -874,8 +784,7 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
bool should_break = false; bool should_break = false;
if ( g_debugger_state.BreakBeforeNextStmt() || f->BreakBeforeNextStmt() ) if ( g_debugger_state.BreakBeforeNextStmt() || f->BreakBeforeNextStmt() ) {
{
if ( g_debugger_state.BreakBeforeNextStmt() ) if ( g_debugger_state.BreakBeforeNextStmt() )
g_debugger_state.BreakBeforeNextStmt(false); g_debugger_state.BreakBeforeNextStmt(false);
@ -885,8 +794,7 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
should_break = true; should_break = true;
} }
if ( stmt->BPCount() ) if ( stmt->BPCount() ) {
{
pair<BPMapType::iterator, BPMapType::iterator> p; pair<BPMapType::iterator, BPMapType::iterator> p;
p = g_debugger_state.breakpoint_map.equal_range(stmt); p = g_debugger_state.breakpoint_map.equal_range(stmt);
@ -894,8 +802,7 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
if ( p.first == p.second ) if ( p.first == p.second )
reporter->InternalError("Breakpoint count nonzero, but no matching breakpoints"); reporter->InternalError("Breakpoint count nonzero, but no matching breakpoints");
for ( BPMapType::iterator i = p.first; i != p.second; ++i ) for ( BPMapType::iterator i = p.first; i != p.second; ++i ) {
{
int break_code = i->second->ShouldBreak(stmt); int break_code = i->second->ShouldBreak(stmt);
if ( break_code == 2 ) // ### 2? if ( break_code == 2 ) // ### 2?
{ {
@ -911,10 +818,9 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
dbg_handle_debug_input(); dbg_handle_debug_input();
return true; return true;
} }
bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow) bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow) {
{
// Handle the case where someone issues a "next" debugger command, // Handle the case where someone issues a "next" debugger command,
// but we're at a return statement, so the next statement is in // but we're at a return statement, so the next statement is in
// some other function. // some other function.
@ -922,10 +828,8 @@ bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow)
g_debugger_state.BreakBeforeNextStmt(true); g_debugger_state.BreakBeforeNextStmt(true);
// Handle "finish" commands. // Handle "finish" commands.
if ( *flow == FLOW_RETURN && f->BreakOnReturn() ) if ( *flow == FLOW_RETURN && f->BreakOnReturn() ) {
{ if ( result ) {
if ( result )
{
ODesc d; ODesc d;
result->Describe(&d); result->Describe(&d);
debug_msg("Return Value: '%s'\n", d.Description()); debug_msg("Return Value: '%s'\n", d.Description());
@ -938,18 +842,16 @@ bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow)
} }
return true; return true;
} }
ValPtr dbg_eval_expr(const char* expr) ValPtr dbg_eval_expr(const char* expr) {
{
// Push the current frame's associated scope. // Push the current frame's associated scope.
// Note: g_debugger_state.curr_frame_idx is the user-visible number, // Note: g_debugger_state.curr_frame_idx is the user-visible number,
// while the array index goes in the opposite direction // while the array index goes in the opposite direction
int frame_idx = (g_frame_stack.size() - 1) - g_debugger_state.curr_frame_idx; int frame_idx = (g_frame_stack.size() - 1) - g_debugger_state.curr_frame_idx;
if ( ! (frame_idx >= 0 && (unsigned)frame_idx < g_frame_stack.size()) ) if ( ! (frame_idx >= 0 && (unsigned)frame_idx < g_frame_stack.size()) )
reporter->InternalError( reporter->InternalError("Assertion failed: frame_idx >= 0 && (unsigned) frame_idx < g_frame_stack.size()");
"Assertion failed: frame_idx >= 0 && (unsigned) frame_idx < g_frame_stack.size()");
Frame* frame = g_frame_stack[frame_idx]; Frame* frame = g_frame_stack[frame_idx];
if ( ! (frame) ) if ( ! (frame) )
@ -973,15 +875,13 @@ ValPtr dbg_eval_expr(const char* expr)
// Parse the thing into an expr. // Parse the thing into an expr.
ValPtr result; ValPtr result;
if ( yyparse() ) if ( yyparse() ) {
{
if ( g_curr_debug_error ) if ( g_curr_debug_error )
debug_msg("Parsing expression '%s' failed: %s\n", expr, g_curr_debug_error); debug_msg("Parsing expression '%s' failed: %s\n", expr, g_curr_debug_error);
else else
debug_msg("Parsing expression '%s' failed\n", expr); debug_msg("Parsing expression '%s' failed\n", expr);
if ( g_curr_debug_expr ) if ( g_curr_debug_expr ) {
{
delete g_curr_debug_expr; delete g_curr_debug_expr;
g_curr_debug_expr = nullptr; g_curr_debug_expr = nullptr;
} }
@ -999,6 +899,6 @@ ValPtr dbg_eval_expr(const char* expr)
in_debug = false; in_debug = false;
return result; return result;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -11,17 +11,16 @@
#include "zeek/StmtEnums.h" #include "zeek/StmtEnums.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
class Val; class Val;
template <class T> class IntrusivePtr; template<class T>
class IntrusivePtr;
using ValPtr = zeek::IntrusivePtr<Val>; using ValPtr = zeek::IntrusivePtr<Val>;
extern std::string current_module; extern std::string current_module;
namespace detail namespace detail {
{
class Frame; class Frame;
class Stmt; class Stmt;
@ -30,20 +29,14 @@ class DbgWatch;
class DbgDisplay; class DbgDisplay;
// This needs to be defined before we do the includes that come after it. // This needs to be defined before we do the includes that come after it.
enum ParseLocationRecType enum ParseLocationRecType { PLR_UNKNOWN, PLR_FILE_AND_LINE, PLR_FUNCTION };
{ class ParseLocationRec {
PLR_UNKNOWN,
PLR_FILE_AND_LINE,
PLR_FUNCTION
};
class ParseLocationRec
{
public: public:
ParseLocationRecType type; ParseLocationRecType type;
int32_t line; int32_t line;
Stmt* stmt; Stmt* stmt;
const char* filename; const char* filename;
}; };
class StmtLocMapping; class StmtLocMapping;
using Filemap = std::deque<StmtLocMapping*>; // mapping for a single file using Filemap = std::deque<StmtLocMapping*>; // mapping for a single file
@ -51,11 +44,9 @@ using Filemap = std::deque<StmtLocMapping*>; // mapping for a single file
using BPIDMapType = std::map<int, DbgBreakpoint*>; using BPIDMapType = std::map<int, DbgBreakpoint*>;
using BPMapType = std::multimap<const Stmt*, DbgBreakpoint*>; using BPMapType = std::multimap<const Stmt*, DbgBreakpoint*>;
class TraceState class TraceState {
{
public: public:
TraceState() TraceState() {
{
dbgtrace = false; dbgtrace = false;
trace_file = stderr; trace_file = stderr;
} }
@ -73,12 +64,11 @@ public:
protected: protected:
bool dbgtrace; // print an execution trace bool dbgtrace; // print an execution trace
FILE* trace_file; FILE* trace_file;
}; };
extern TraceState g_trace_state; extern TraceState g_trace_state;
class DebuggerState class DebuggerState {
{
public: public:
DebuggerState(); DebuggerState();
~DebuggerState(); ~DebuggerState();
@ -95,7 +85,7 @@ public:
// Temporary state: vanishes when execution resumes. // Temporary state: vanishes when execution resumes.
//### Umesh, why do these all need to be public? -- Vern // ### Umesh, why do these all need to be public? -- Vern
// Which frame we're looking at; 0 = the innermost frame. // Which frame we're looking at; 0 = the innermost frame.
int curr_frame_idx; int curr_frame_idx;
@ -117,16 +107,14 @@ protected:
private: private:
Frame* dbg_locals; // unused Frame* dbg_locals; // unused
}; };
// Source line -> statement mapping. // Source line -> statement mapping.
// (obj -> source line mapping available in object itself) // (obj -> source line mapping available in object itself)
class StmtLocMapping class StmtLocMapping {
{
public: public:
StmtLocMapping() { } StmtLocMapping() {}
StmtLocMapping(const Location* l, Stmt* s) StmtLocMapping(const Location* l, Stmt* s) {
{
loc = *l; loc = *l;
stmt = s; stmt = s;
} }
@ -138,7 +126,7 @@ public:
protected: protected:
Location loc; Location loc;
Stmt* stmt = nullptr; Stmt* stmt = nullptr;
}; };
extern bool g_policy_debug; // enable debugging facility extern bool g_policy_debug; // enable debugging facility
extern DebuggerState g_debugger_state; extern DebuggerState g_debugger_state;
@ -195,5 +183,5 @@ extern std::map<std::string, Filemap*> g_dbgfilemaps; // filename => filemap
// Perhaps add a code/priority argument to do selective output. // Perhaps add a code/priority argument to do selective output.
int debug_msg(const char* fmt, ...) __attribute__((format(printf, 1, 2))); int debug_msg(const char* fmt, ...) __attribute__((format(printf, 1, 2)));
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -26,29 +26,22 @@
using namespace std; using namespace std;
namespace zeek::detail namespace zeek::detail {
{
DebugCmdInfoQueue g_DebugCmdInfos; DebugCmdInfoQueue g_DebugCmdInfos;
// //
// Helper routines // Helper routines
// //
static bool string_is_regex(const string& s) static bool string_is_regex(const string& s) { return strpbrk(s.data(), "?*\\+"); }
{
return strpbrk(s.data(), "?*\\+");
}
static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& matches, static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& matches, bool func_only = false) {
bool func_only = false)
{
if ( util::streq(orig_regex.c_str(), "") ) if ( util::streq(orig_regex.c_str(), "") )
return; return;
string regex = "^"; string regex = "^";
int len = orig_regex.length(); int len = orig_regex.length();
for ( int i = 0; i < len; ++i ) for ( int i = 0; i < len; ++i ) {
{
if ( orig_regex[i] == '*' ) if ( orig_regex[i] == '*' )
regex.push_back('.'); regex.push_back('.');
regex.push_back(orig_regex[i]); regex.push_back(orig_regex[i]);
@ -56,8 +49,7 @@ static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& m
regex.push_back('$'); regex.push_back('$');
regex_t re; regex_t re;
if ( regcomp(&re, regex.c_str(), REG_EXTENDED | REG_NOSUB) ) if ( regcomp(&re, regex.c_str(), REG_EXTENDED | REG_NOSUB) ) {
{
debug_msg("Invalid regular expression: %s\n", regex.c_str()); debug_msg("Invalid regular expression: %s\n", regex.c_str());
return; return;
} }
@ -66,25 +58,21 @@ static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& m
const auto& syms = global->Vars(); const auto& syms = global->Vars();
ID* nextid; ID* nextid;
for ( const auto& sym : syms ) for ( const auto& sym : syms ) {
{
ID* nextid = sym.second.get(); ID* nextid = sym.second.get();
if ( ! func_only || nextid->GetType()->Tag() == TYPE_FUNC ) if ( ! func_only || nextid->GetType()->Tag() == TYPE_FUNC )
if ( ! regexec(&re, nextid->Name(), 0, 0, 0) ) if ( ! regexec(&re, nextid->Name(), 0, 0, 0) )
matches.push_back(nextid); matches.push_back(nextid);
} }
} }
static void choose_global_symbols_regex(const string& regex, vector<ID*>& choices, static void choose_global_symbols_regex(const string& regex, vector<ID*>& choices, bool func_only = false) {
bool func_only = false)
{
lookup_global_symbols_regex(regex, choices, func_only); lookup_global_symbols_regex(regex, choices, func_only);
if ( choices.size() <= 1 ) if ( choices.size() <= 1 )
return; return;
while ( true ) while ( true ) {
{
debug_msg("There were multiple matches, please choose:\n"); debug_msg("There were multiple matches, please choose:\n");
for ( size_t i = 0; i < choices.size(); i++ ) for ( size_t i = 0; i < choices.size(); i++ )
@ -95,8 +83,7 @@ static void choose_global_symbols_regex(const string& regex, vector<ID*>& choice
debug_msg("Enter your choice: "); debug_msg("Enter your choice: ");
char charinput[256]; char charinput[256];
if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) {
{
choices.clear(); choices.clear();
return; return;
} }
@ -107,75 +94,65 @@ static void choose_global_symbols_regex(const string& regex, vector<ID*>& choice
if ( input == "a" ) if ( input == "a" )
return; return;
if ( input == "n" ) if ( input == "n" ) {
{
choices.clear(); choices.clear();
return; return;
} }
int option = atoi(input.c_str()); int option = atoi(input.c_str());
if ( option > 0 && option <= (int)choices.size() ) if ( option > 0 && option <= (int)choices.size() ) {
{
ID* choice = choices[option - 1]; ID* choice = choices[option - 1];
choices.clear(); choices.clear();
choices.push_back(choice); choices.push_back(choice);
return; return;
} }
} }
} }
// //
// DebugCmdInfo implementation // DebugCmdInfo implementation
// //
DebugCmdInfo::DebugCmdInfo(const DebugCmdInfo& info) : cmd(info.cmd), helpstring(nullptr) DebugCmdInfo::DebugCmdInfo(const DebugCmdInfo& info) : cmd(info.cmd), helpstring(nullptr) {
{
num_names = info.num_names; num_names = info.num_names;
names = info.names; names = info.names;
resume_execution = info.resume_execution; resume_execution = info.resume_execution;
repeatable = info.repeatable; repeatable = info.repeatable;
} }
DebugCmdInfo::DebugCmdInfo(DebugCmd arg_cmd, const char* const* arg_names, int arg_num_names, DebugCmdInfo::DebugCmdInfo(DebugCmd arg_cmd, const char* const* arg_names, int arg_num_names, bool arg_resume_execution,
bool arg_resume_execution, const char* const arg_helpstring, const char* const arg_helpstring, bool arg_repeatable)
bool arg_repeatable) : cmd(arg_cmd), helpstring(arg_helpstring) {
: cmd(arg_cmd), helpstring(arg_helpstring)
{
num_names = arg_num_names; num_names = arg_num_names;
resume_execution = arg_resume_execution; resume_execution = arg_resume_execution;
repeatable = arg_repeatable; repeatable = arg_repeatable;
for ( int i = 0; i < num_names; ++i ) for ( int i = 0; i < num_names; ++i )
names.push_back(arg_names[i]); names.push_back(arg_names[i]);
} }
const DebugCmdInfo* get_debug_cmd_info(DebugCmd cmd) const DebugCmdInfo* get_debug_cmd_info(DebugCmd cmd) {
{
if ( (int)cmd < g_DebugCmdInfos.size() ) if ( (int)cmd < g_DebugCmdInfos.size() )
return g_DebugCmdInfos[(int)cmd]; return g_DebugCmdInfos[(int)cmd];
else else
return nullptr; return nullptr;
} }
int find_all_matching_cmds(const string& prefix, const char* array_of_matches[]) int find_all_matching_cmds(const string& prefix, const char* array_of_matches[]) {
{
// Trivial implementation for now (### use hashing later). // Trivial implementation for now (### use hashing later).
unsigned int arglen = prefix.length(); unsigned int arglen = prefix.length();
int matches = 0; int matches = 0;
for ( int i = 0; i < num_debug_cmds(); ++i ) for ( int i = 0; i < num_debug_cmds(); ++i ) {
{
array_of_matches[g_DebugCmdInfos[i]->Cmd()] = nullptr; array_of_matches[g_DebugCmdInfos[i]->Cmd()] = nullptr;
for ( int j = 0; j < g_DebugCmdInfos[i]->NumNames(); ++j ) for ( int j = 0; j < g_DebugCmdInfos[i]->NumNames(); ++j ) {
{
const char* curr_name = g_DebugCmdInfos[i]->Names()[j]; const char* curr_name = g_DebugCmdInfos[i]->Names()[j];
if ( strncmp(curr_name, prefix.c_str(), arglen) ) if ( strncmp(curr_name, prefix.c_str(), arglen) )
continue; continue;
// If exact match, then only return that one. // If exact match, then only return that one.
if ( ! prefix.compare(curr_name) ) if ( ! prefix.compare(curr_name) ) {
{
for ( int k = 0; k < num_debug_cmds(); ++k ) for ( int k = 0; k < num_debug_cmds(); ++k )
array_of_matches[k] = nullptr; array_of_matches[k] = nullptr;
@ -189,28 +166,24 @@ int find_all_matching_cmds(const string& prefix, const char* array_of_matches[])
} }
return matches; return matches;
} }
// //
// ------------------------------------------------------------ // ------------------------------------------------------------
// Implementation of some debugger commands // Implementation of some debugger commands
// Start, end bounds of which frame numbers to print // Start, end bounds of which frame numbers to print
static int dbg_backtrace_internal(int start, int end) static int dbg_backtrace_internal(int start, int end) {
{ if ( start < 0 || end < 0 || (unsigned)start >= g_frame_stack.size() || (unsigned)end >= g_frame_stack.size() )
if ( start < 0 || end < 0 || (unsigned)start >= g_frame_stack.size() ||
(unsigned)end >= g_frame_stack.size() )
reporter->InternalError("Invalid stack frame index in DbgBacktraceInternal\n"); reporter->InternalError("Invalid stack frame index in DbgBacktraceInternal\n");
if ( start < end ) if ( start < end ) {
{
int temp = start; int temp = start;
start = end; start = end;
end = temp; end = temp;
} }
for ( int i = start; i >= end; --i ) for ( int i = start; i >= end; --i ) {
{
const Frame* f = g_frame_stack[i]; const Frame* f = g_frame_stack[i];
const Stmt* stmt = f ? f->GetNextStmt() : nullptr; const Stmt* stmt = f ? f->GetNextStmt() : nullptr;
@ -219,76 +192,64 @@ static int dbg_backtrace_internal(int start, int end)
}; };
return 1; return 1;
} }
// Returns 0 for illegal arguments, or 1 on success. // Returns 0 for illegal arguments, or 1 on success.
int dbg_cmd_backtrace(DebugCmd cmd, const vector<string>& args) int dbg_cmd_backtrace(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcBacktrace); assert(cmd == dcBacktrace);
assert(g_frame_stack.size() > 0); assert(g_frame_stack.size() > 0);
unsigned int start_iter; unsigned int start_iter;
int end_iter; int end_iter;
if ( args.size() > 0 ) if ( args.size() > 0 ) {
{
int how_many; // determines how we traverse the frames int how_many; // determines how we traverse the frames
int valid_arg = sscanf(args[0].c_str(), "%i", &how_many); int valid_arg = sscanf(args[0].c_str(), "%i", &how_many);
if ( ! valid_arg ) if ( ! valid_arg ) {
{
debug_msg("Argument to backtrace '%s' invalid: must be an integer\n", args[0].c_str()); debug_msg("Argument to backtrace '%s' invalid: must be an integer\n", args[0].c_str());
return 0; return 0;
} }
if ( how_many > 0 ) if ( how_many > 0 ) { // innermost N frames
{ // innermost N frames
start_iter = g_frame_stack.size() - 1; start_iter = g_frame_stack.size() - 1;
end_iter = start_iter - how_many + 1; end_iter = start_iter - how_many + 1;
if ( end_iter < 0 ) if ( end_iter < 0 )
end_iter = 0; end_iter = 0;
} }
else else { // outermost N frames
{ // outermost N frames
start_iter = how_many - 1; start_iter = how_many - 1;
if ( start_iter + 1 > g_frame_stack.size() ) if ( start_iter + 1 > g_frame_stack.size() )
start_iter = g_frame_stack.size() - 1; start_iter = g_frame_stack.size() - 1;
end_iter = 0; end_iter = 0;
} }
} }
else else {
{
start_iter = g_frame_stack.size() - 1; start_iter = g_frame_stack.size() - 1;
end_iter = 0; end_iter = 0;
} }
return dbg_backtrace_internal(start_iter, end_iter); return dbg_backtrace_internal(start_iter, end_iter);
} }
// Returns 0 if invalid args, else 1. // Returns 0 if invalid args, else 1.
int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args) int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcFrame || cmd == dcUp || cmd == dcDown); assert(cmd == dcFrame || cmd == dcUp || cmd == dcDown);
if ( cmd == dcFrame ) if ( cmd == dcFrame ) {
{
int idx = 0; int idx = 0;
if ( args.size() > 0 ) if ( args.size() > 0 ) {
{ if ( args.size() > 1 ) {
if ( args.size() > 1 )
{
debug_msg("Too many arguments: expecting frame number 'n'\n"); debug_msg("Too many arguments: expecting frame number 'n'\n");
return 0; return 0;
} }
if ( ! sscanf(args[0].c_str(), "%d", &idx) ) if ( ! sscanf(args[0].c_str(), "%d", &idx) ) {
{
debug_msg("Argument to frame must be a positive integer\n"); debug_msg("Argument to frame must be a positive integer\n");
return 0; return 0;
} }
if ( idx < 0 || (unsigned int)idx >= g_frame_stack.size() ) if ( idx < 0 || (unsigned int)idx >= g_frame_stack.size() ) {
{
debug_msg("No frame %d", idx); debug_msg("No frame %d", idx);
return 0; return 0;
} }
@ -297,10 +258,8 @@ int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args)
g_debugger_state.curr_frame_idx = idx; g_debugger_state.curr_frame_idx = idx;
} }
else if ( cmd == dcDown ) else if ( cmd == dcDown ) {
{ if ( g_debugger_state.curr_frame_idx == 0 ) {
if ( g_debugger_state.curr_frame_idx == 0 )
{
debug_msg("Innermost frame already selected\n"); debug_msg("Innermost frame already selected\n");
return 0; return 0;
} }
@ -308,10 +267,8 @@ int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args)
g_debugger_state.curr_frame_idx--; g_debugger_state.curr_frame_idx--;
} }
else if ( cmd == dcUp ) else if ( cmd == dcUp ) {
{ if ( (unsigned int)(g_debugger_state.curr_frame_idx + 1) == g_frame_stack.size() ) {
if ( (unsigned int)(g_debugger_state.curr_frame_idx + 1) == g_frame_stack.size() )
{
debug_msg("Outermost frame already selected\n"); debug_msg("Outermost frame already selected\n");
return 0; return 0;
} }
@ -332,32 +289,28 @@ int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args)
g_debugger_state.already_did_list = false; g_debugger_state.already_did_list = false;
return dbg_backtrace_internal(user_frame_number, user_frame_number); return dbg_backtrace_internal(user_frame_number, user_frame_number);
} }
int dbg_cmd_help(DebugCmd cmd, const vector<string>& args) int dbg_cmd_help(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcHelp); assert(cmd == dcHelp);
debug_msg("Help summary: \n\n"); debug_msg("Help summary: \n\n");
for ( int i = 1; i < num_debug_cmds(); ++i ) for ( int i = 1; i < num_debug_cmds(); ++i ) {
{
const DebugCmdInfo* info = get_debug_cmd_info(DebugCmd(i)); const DebugCmdInfo* info = get_debug_cmd_info(DebugCmd(i));
debug_msg("%s -- %s\n", info->Names()[0], info->Helpstring()); debug_msg("%s -- %s\n", info->Names()[0], info->Helpstring());
} }
return -1; return -1;
} }
int dbg_cmd_break(DebugCmd cmd, const vector<string>& args) int dbg_cmd_break(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcBreak); assert(cmd == dcBreak);
vector<DbgBreakpoint*> bps; vector<DbgBreakpoint*> bps;
int cond_index = -1; // at which argument pos. does bp condition start? int cond_index = -1; // at which argument pos. does bp condition start?
if ( args.empty() || args[0] == "if" ) if ( args.empty() || args[0] == "if" ) { // break on next stmt
{ // break on next stmt
int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx; int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx;
Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt(); Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt();
@ -367,8 +320,7 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
DbgBreakpoint* bp = new DbgBreakpoint(); DbgBreakpoint* bp = new DbgBreakpoint();
bp->SetID(g_debugger_state.NextBPID()); bp->SetID(g_debugger_state.NextBPID());
if ( ! bp->SetLocation(stmt) ) if ( ! bp->SetLocation(stmt) ) {
{
debug_msg("Breakpoint not set.\n"); debug_msg("Breakpoint not set.\n");
delete bp; delete bp;
return 0; return 0;
@ -380,11 +332,9 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
bps.push_back(bp); bps.push_back(bp);
} }
else else {
{
vector<string> locstrings; vector<string> locstrings;
if ( string_is_regex(args[0]) ) if ( string_is_regex(args[0]) ) {
{
vector<ID*> choices; vector<ID*> choices;
choose_global_symbols_regex(args[0], choices, true); choose_global_symbols_regex(args[0], choices, true);
for ( unsigned int i = 0; i < choices.size(); ++i ) for ( unsigned int i = 0; i < choices.size(); ++i )
@ -393,16 +343,13 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
else else
locstrings.push_back(args[0]); locstrings.push_back(args[0]);
for ( unsigned int strindex = 0; strindex < locstrings.size(); ++strindex ) for ( unsigned int strindex = 0; strindex < locstrings.size(); ++strindex ) {
{
debug_msg("Setting breakpoint on %s:\n", locstrings[strindex].c_str()); debug_msg("Setting breakpoint on %s:\n", locstrings[strindex].c_str());
vector<ParseLocationRec> plrs = parse_location_string(locstrings[strindex]); vector<ParseLocationRec> plrs = parse_location_string(locstrings[strindex]);
for ( const auto& plr : plrs ) for ( const auto& plr : plrs ) {
{
DbgBreakpoint* bp = new DbgBreakpoint(); DbgBreakpoint* bp = new DbgBreakpoint();
bp->SetID(g_debugger_state.NextBPID()); bp->SetID(g_debugger_state.NextBPID());
if ( ! bp->SetLocation(plr, locstrings[strindex]) ) if ( ! bp->SetLocation(plr, locstrings[strindex]) ) {
{
debug_msg("Breakpoint not set.\n"); debug_msg("Breakpoint not set.\n");
delete bp; delete bp;
} }
@ -416,34 +363,29 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
} }
// Is there a condition specified? // Is there a condition specified?
if ( cond_index >= 0 && ! bps.empty() ) if ( cond_index >= 0 && ! bps.empty() ) {
{
// ### Implement conditions // ### Implement conditions
string cond; string cond;
for ( const auto& arg : args ) for ( const auto& arg : args ) {
{
cond += arg; cond += arg;
cond += " "; cond += " ";
} }
bps[0]->SetCondition(cond); bps[0]->SetCondition(cond);
} }
for ( auto& bp : bps ) for ( auto& bp : bps ) {
{
bp->SetTemporary(false); bp->SetTemporary(false);
g_debugger_state.breakpoints[bp->GetID()] = bp; g_debugger_state.breakpoints[bp->GetID()] = bp;
} }
return 0; return 0;
} }
// Set a condition on an existing breakpoint. // Set a condition on an existing breakpoint.
int dbg_cmd_break_condition(DebugCmd cmd, const vector<string>& args) int dbg_cmd_break_condition(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcBreakCondition); assert(cmd == dcBreakCondition);
if ( args.size() < 2 ) if ( args.size() < 2 ) {
{
debug_msg("Arguments must specify breakpoint number and condition.\n"); debug_msg("Arguments must specify breakpoint number and condition.\n");
return 0; return 0;
} }
@ -452,58 +394,48 @@ int dbg_cmd_break_condition(DebugCmd cmd, const vector<string>& args)
DbgBreakpoint* bp = g_debugger_state.breakpoints[idx]; DbgBreakpoint* bp = g_debugger_state.breakpoints[idx];
string expr; string expr;
for ( int i = 1; i < int(args.size()); ++i ) for ( int i = 1; i < int(args.size()); ++i ) {
{
expr += args[i]; expr += args[i];
expr += " "; expr += " ";
} }
bp->SetCondition(expr); bp->SetCondition(expr);
return 1; return 1;
} }
// Change the state of a breakpoint. // Change the state of a breakpoint.
int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args) int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args) {
{ assert(cmd == dcDeleteBreak || cmd == dcClearBreak || cmd == dcDisableBreak || cmd == dcEnableBreak ||
assert(cmd == dcDeleteBreak || cmd == dcClearBreak || cmd == dcDisableBreak || cmd == dcIgnoreBreak);
cmd == dcEnableBreak || cmd == dcIgnoreBreak);
if ( cmd == dcClearBreak || cmd == dcIgnoreBreak ) if ( cmd == dcClearBreak || cmd == dcIgnoreBreak ) {
{
debug_msg("'clear' and 'ignore' commands not currently supported\n"); debug_msg("'clear' and 'ignore' commands not currently supported\n");
return 0; return 0;
} }
if ( g_debugger_state.breakpoints.empty() ) if ( g_debugger_state.breakpoints.empty() ) {
{
debug_msg("No breakpoints currently set.\n"); debug_msg("No breakpoints currently set.\n");
return -1; return -1;
} }
vector<int> bps_to_change; vector<int> bps_to_change;
if ( args.empty() ) if ( args.empty() ) {
{
BPIDMapType::iterator iter; BPIDMapType::iterator iter;
for ( iter = g_debugger_state.breakpoints.begin(); for ( iter = g_debugger_state.breakpoints.begin(); iter != g_debugger_state.breakpoints.end(); ++iter )
iter != g_debugger_state.breakpoints.end(); ++iter )
bps_to_change.push_back(iter->second->GetID()); bps_to_change.push_back(iter->second->GetID());
} }
else else {
{
for ( const auto& arg : args ) for ( const auto& arg : args )
if ( int idx = atoi(arg.c_str()) ) if ( int idx = atoi(arg.c_str()) )
bps_to_change.push_back(idx); bps_to_change.push_back(idx);
} }
for ( auto bp_change : bps_to_change ) for ( auto bp_change : bps_to_change ) {
{
BPIDMapType::iterator result = g_debugger_state.breakpoints.find(bp_change); BPIDMapType::iterator result = g_debugger_state.breakpoints.find(bp_change);
if ( result != g_debugger_state.breakpoints.end() ) if ( result != g_debugger_state.breakpoints.end() ) {
{ switch ( cmd ) {
switch ( cmd )
{
case dcDisableBreak: case dcDisableBreak:
g_debugger_state.breakpoints[bp_change]->SetEnable(false); g_debugger_state.breakpoints[bp_change]->SetEnable(false);
debug_msg("Breakpoint %d disabled\n", bp_change); debug_msg("Breakpoint %d disabled\n", bp_change);
@ -520,8 +452,7 @@ int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args)
debug_msg("Breakpoint %d deleted\n", bp_change); debug_msg("Breakpoint %d deleted\n", bp_change);
break; break;
default: default: reporter->InternalError("Invalid command in DbgCmdBreakSetState\n");
reporter->InternalError("Invalid command in DbgCmdBreakSetState\n");
} }
} }
@ -530,19 +461,17 @@ int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args)
} }
return -1; return -1;
} }
// Evaluate an expression and print the result. // Evaluate an expression and print the result.
int dbg_cmd_print(DebugCmd cmd, const vector<string>& args) int dbg_cmd_print(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcPrint); assert(cmd == dcPrint);
// ### TODO: add support for formats // ### TODO: add support for formats
// Just concatenate all the 'args' into one expression. // Just concatenate all the 'args' into one expression.
string expr; string expr;
for ( size_t i = 0; i < args.size(); ++i ) for ( size_t i = 0; i < args.size(); ++i ) {
{
expr += args[i]; expr += args[i];
if ( i < args.size() - 1 ) if ( i < args.size() - 1 )
expr += " "; expr += " ";
@ -550,28 +479,24 @@ int dbg_cmd_print(DebugCmd cmd, const vector<string>& args)
auto val = dbg_eval_expr(expr.c_str()); auto val = dbg_eval_expr(expr.c_str());
if ( val ) if ( val ) {
{
ODesc d; ODesc d;
val->Describe(&d); val->Describe(&d);
debug_msg("%s\n", d.Description()); debug_msg("%s\n", d.Description());
} }
else else {
{
debug_msg("<expression has no value>\n"); debug_msg("<expression has no value>\n");
} }
return 1; return 1;
} }
// Get the debugger's state. // Get the debugger's state.
// Allowed arguments are: break (breakpoints), watch, display, source. // Allowed arguments are: break (breakpoints), watch, display, source.
int dbg_cmd_info(DebugCmd cmd, const vector<string>& args) int dbg_cmd_info(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcInfo); assert(cmd == dcInfo);
if ( args.empty() ) if ( args.empty() ) {
{
debug_msg("Syntax: info info-command\n"); debug_msg("Syntax: info info-command\n");
debug_msg("List of info-commands:\n"); debug_msg("List of info-commands:\n");
debug_msg("info breakpoints -- List of breakpoints and watches\n"); debug_msg("info breakpoints -- List of breakpoints and watches\n");
@ -579,18 +504,14 @@ int dbg_cmd_info(DebugCmd cmd, const vector<string>& args)
} }
if ( ! strncmp(args[0].c_str(), "breakpoints", args[0].size()) || if ( ! strncmp(args[0].c_str(), "breakpoints", args[0].size()) ||
! strncmp(args[0].c_str(), "watch", args[0].size()) ) ! strncmp(args[0].c_str(), "watch", args[0].size()) ) {
{
debug_msg("Num Type Disp Enb What\n"); debug_msg("Num Type Disp Enb What\n");
BPIDMapType::iterator iter; BPIDMapType::iterator iter;
for ( iter = g_debugger_state.breakpoints.begin(); for ( iter = g_debugger_state.breakpoints.begin(); iter != g_debugger_state.breakpoints.end(); ++iter ) {
iter != g_debugger_state.breakpoints.end(); ++iter )
{
DbgBreakpoint* bp = (*iter).second; DbgBreakpoint* bp = (*iter).second;
debug_msg("%-4d%-15s%-5s%-4s%s\n", bp->GetID(), "breakpoint", debug_msg("%-4d%-15s%-5s%-4s%s\n", bp->GetID(), "breakpoint", bp->IsTemporary() ? "del" : "keep",
bp->IsTemporary() ? "del" : "keep", bp->IsEnabled() ? "y" : "n", bp->IsEnabled() ? "y" : "n", bp->Description());
bp->Description());
} }
} }
@ -598,24 +519,21 @@ int dbg_cmd_info(DebugCmd cmd, const vector<string>& args)
debug_msg("I don't have info for that yet.\n"); debug_msg("I don't have info for that yet.\n");
return 1; return 1;
} }
int dbg_cmd_list(DebugCmd cmd, const vector<string>& args) int dbg_cmd_list(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcList); assert(cmd == dcList);
// The constant 4 is to match the GDB behavior. // The constant 4 is to match the GDB behavior.
const unsigned int CENTER_IDX = 4; // 5th line is the 'interesting' one const unsigned int CENTER_IDX = 4; // 5th line is the 'interesting' one
int pre_offset = 0; int pre_offset = 0;
if ( args.size() > 1 ) if ( args.size() > 1 ) {
{
debug_msg("Syntax: list [file:]line OR list function_name\n"); debug_msg("Syntax: list [file:]line OR list function_name\n");
return 0; return 0;
} }
if ( args.empty() ) if ( args.empty() ) {
{
// Special case: if we just hit a breakpoint, then show // Special case: if we just hit a breakpoint, then show
// that line without advancing first. // that line without advancing first.
if ( g_debugger_state.already_did_list ) if ( g_debugger_state.already_did_list )
@ -626,11 +544,9 @@ int dbg_cmd_list(DebugCmd cmd, const vector<string>& args)
// Why -10 ? Because that's what GDB does. // Why -10 ? Because that's what GDB does.
pre_offset = -10; pre_offset = -10;
else if ( args[0][0] == '-' || args[0][0] == '+' ) else if ( args[0][0] == '-' || args[0][0] == '+' ) {
{
int offset; int offset;
if ( ! sscanf(args[0].c_str(), "%d", &offset) ) if ( ! sscanf(args[0].c_str(), "%d", &offset) ) {
{
debug_msg("Offset must be a number\n"); debug_msg("Offset must be a number\n");
return false; return false;
} }
@ -638,12 +554,10 @@ int dbg_cmd_list(DebugCmd cmd, const vector<string>& args)
pre_offset = offset; pre_offset = offset;
} }
else else {
{
vector<ParseLocationRec> plrs = parse_location_string(args[0]); vector<ParseLocationRec> plrs = parse_location_string(args[0]);
ParseLocationRec plr = plrs[0]; ParseLocationRec plr = plrs[0];
if ( plr.type == PLR_UNKNOWN ) if ( plr.type == PLR_UNKNOWN ) {
{
debug_msg("Invalid location specifier\n"); debug_msg("Invalid location specifier\n");
return false; return false;
} }
@ -663,38 +577,33 @@ int dbg_cmd_list(DebugCmd cmd, const vector<string>& args)
if ( g_debugger_state.last_loc.first_line > last_line_in_file ) if ( g_debugger_state.last_loc.first_line > last_line_in_file )
g_debugger_state.last_loc.first_line = last_line_in_file; g_debugger_state.last_loc.first_line = last_line_in_file;
PrintLines(g_debugger_state.last_loc.filename, PrintLines(g_debugger_state.last_loc.filename, g_debugger_state.last_loc.first_line - CENTER_IDX, 10, true);
g_debugger_state.last_loc.first_line - CENTER_IDX, 10, true);
g_debugger_state.already_did_list = true; g_debugger_state.already_did_list = true;
return 1; return 1;
} }
int dbg_cmd_trace(DebugCmd cmd, const vector<string>& args) int dbg_cmd_trace(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcTrace); assert(cmd == dcTrace);
if ( args.empty() ) if ( args.empty() ) {
{
debug_msg("Execution tracing is %s.\n", g_trace_state.DoTrace() ? "on" : "off"); debug_msg("Execution tracing is %s.\n", g_trace_state.DoTrace() ? "on" : "off");
return 1; return 1;
} }
if ( args[0] == "on" ) if ( args[0] == "on" ) {
{
g_trace_state.TraceOn(); g_trace_state.TraceOn();
return 1; return 1;
} }
if ( args[0] == "off" ) if ( args[0] == "off" ) {
{
g_trace_state.TraceOff(); g_trace_state.TraceOff();
return 1; return 1;
} }
debug_msg("Invalid argument"); debug_msg("Invalid argument");
return 0; return 0;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -11,18 +11,16 @@
// This file is generated during the build. // This file is generated during the build.
#include "DebugCmdConstants.h" #include "DebugCmdConstants.h"
namespace zeek::detail namespace zeek::detail {
{
class DebugCmdInfo class DebugCmdInfo {
{
public: public:
DebugCmdInfo(const DebugCmdInfo& info); DebugCmdInfo(const DebugCmdInfo& info);
DebugCmdInfo(DebugCmd cmd, const char* const* names, int num_names, bool resume_execution, DebugCmdInfo(DebugCmd cmd, const char* const* names, int num_names, bool resume_execution,
const char* const helpstring, bool repeatable); const char* const helpstring, bool repeatable);
DebugCmdInfo() : helpstring(nullptr) { } DebugCmdInfo() : helpstring(nullptr) {}
int Cmd() const { return cmd; } int Cmd() const { return cmd; }
int NumNames() const { return num_names; } int NumNames() const { return num_names; }
@ -43,7 +41,7 @@ protected:
// Does entering a blank line repeat this command? // Does entering a blank line repeat this command?
bool repeatable; bool repeatable;
}; };
using DebugCmdInfoQueue = std::deque<DebugCmdInfo*>; using DebugCmdInfoQueue = std::deque<DebugCmdInfo*>;
extern DebugCmdInfoQueue g_DebugCmdInfos; extern DebugCmdInfoQueue g_DebugCmdInfos;
@ -80,4 +78,4 @@ DbgCmdFn dbg_cmd_info;
DbgCmdFn dbg_cmd_list; DbgCmdFn dbg_cmd_list;
DbgCmdFn dbg_cmd_trace; DbgCmdFn dbg_cmd_trace;
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -11,45 +11,36 @@
zeek::detail::DebugLogger zeek::detail::debug_logger; zeek::detail::DebugLogger zeek::detail::debug_logger;
zeek::detail::DebugLogger& debug_logger = zeek::detail::debug_logger; zeek::detail::DebugLogger& debug_logger = zeek::detail::debug_logger;
namespace zeek::detail namespace zeek::detail {
{
// Same order here as in DebugStream. // Same order here as in DebugStream.
DebugLogger::Stream DebugLogger::streams[NUM_DBGS] = { DebugLogger::Stream DebugLogger::streams[NUM_DBGS] =
{"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, {{"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, {"notifiers", 0, false},
{"notifiers", 0, false}, {"main-loop", 0, false}, {"dpd", 0, false}, {"main-loop", 0, false}, {"dpd", 0, false}, {"packet_analysis", 0, false}, {"file_analysis", 0, false},
{"packet_analysis", 0, false}, {"file_analysis", 0, false}, {"tm", 0, false}, {"tm", 0, false}, {"logging", 0, false}, {"input", 0, false}, {"threading", 0, false},
{"logging", 0, false}, {"input", 0, false}, {"threading", 0, false}, {"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, {"broker", 0, false},
{"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, {"scripts", 0, false}, {"supervisor", 0, false}, {"hashkey", 0, false}, {"spicy", 0, false}};
{"broker", 0, false}, {"scripts", 0, false}, {"supervisor", 0, false},
{"hashkey", 0, false}, {"spicy", 0, false}};
DebugLogger::DebugLogger() DebugLogger::DebugLogger() {
{
verbose = false; verbose = false;
file = nullptr; file = nullptr;
} }
DebugLogger::~DebugLogger() DebugLogger::~DebugLogger() {
{
if ( file && file != stderr ) if ( file && file != stderr )
fclose(file); fclose(file);
} }
void DebugLogger::OpenDebugLog(const char* filename) void DebugLogger::OpenDebugLog(const char* filename) {
{ if ( filename ) {
if ( filename )
{
filename = util::detail::log_file_name(filename); filename = util::detail::log_file_name(filename);
file = fopen(filename, "w"); file = fopen(filename, "w");
if ( ! file ) if ( ! file ) {
{
// The reporter may not be initialized here yet. // The reporter may not be initialized here yet.
if ( reporter ) if ( reporter )
reporter->FatalError("can't open '%s' for debugging output", filename); reporter->FatalError("can't open '%s' for debugging output", filename);
else else {
{
fprintf(stderr, "can't open '%s' for debugging output\n", filename); fprintf(stderr, "can't open '%s' for debugging output\n", filename);
exit(1); exit(1);
} }
@ -59,10 +50,9 @@ void DebugLogger::OpenDebugLog(const char* filename)
} }
else else
file = stderr; file = stderr;
} }
void DebugLogger::ShowStreamsHelp() void DebugLogger::ShowStreamsHelp() {
{
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "Enable debug output into debug.log with -B <streams>.\n"); fprintf(stderr, "Enable debug output into debug.log with -B <streams>.\n");
fprintf(stderr, "<streams> is a comma-separated list of streams to enable.\n"); fprintf(stderr, "<streams> is a comma-separated list of streams to enable.\n");
@ -73,27 +63,24 @@ void DebugLogger::ShowStreamsHelp()
fprintf(stderr, " %s\n", streams[i].prefix); fprintf(stderr, " %s\n", streams[i].prefix);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, " plugin-<plugin-name> (replace '::' in name with '-'; e.g., '-B " fprintf(stderr,
" plugin-<plugin-name> (replace '::' in name with '-'; e.g., '-B "
"plugin-Zeek-Netmap')\n"); "plugin-Zeek-Netmap')\n");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "Pseudo streams\n"); fprintf(stderr, "Pseudo streams\n");
fprintf(stderr, " verbose Increase verbosity.\n"); fprintf(stderr, " verbose Increase verbosity.\n");
fprintf(stderr, " all Enable all streams at maximum verbosity.\n"); fprintf(stderr, " all Enable all streams at maximum verbosity.\n");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
void DebugLogger::EnableStreams(const char* s) void DebugLogger::EnableStreams(const char* s) {
{
char* brkt; char* brkt;
char* tmp = util::copy_string(s); char* tmp = util::copy_string(s);
char* tok = strtok(tmp, ","); char* tok = strtok(tmp, ",");
while ( tok ) while ( tok ) {
{ if ( strcasecmp("all", tok) == 0 ) {
if ( strcasecmp("all", tok) == 0 ) for ( int i = 0; i < NUM_DBGS; ++i ) {
{
for ( int i = 0; i < NUM_DBGS; ++i )
{
streams[i].enabled = true; streams[i].enabled = true;
enabled_streams.insert(streams[i].prefix); enabled_streams.insert(streams[i].prefix);
} }
@ -102,20 +89,17 @@ void DebugLogger::EnableStreams(const char* s)
goto next; goto next;
} }
if ( strcasecmp("verbose", tok) == 0 ) if ( strcasecmp("verbose", tok) == 0 ) {
{
verbose = true; verbose = true;
goto next; goto next;
} }
if ( strcasecmp("help", tok) == 0 ) if ( strcasecmp("help", tok) == 0 ) {
{
ShowStreamsHelp(); ShowStreamsHelp();
exit(0); exit(0);
} }
if ( util::starts_with(tok, "plugin-") ) if ( util::starts_with(tok, "plugin-") ) {
{
// Cannot verify this at this time, plugins may not // Cannot verify this at this time, plugins may not
// have been loaded. // have been loaded.
enabled_streams.insert(tok); enabled_streams.insert(tok);
@ -124,10 +108,8 @@ void DebugLogger::EnableStreams(const char* s)
int i; int i;
for ( i = 0; i < NUM_DBGS; ++i ) for ( i = 0; i < NUM_DBGS; ++i ) {
{ if ( strcasecmp(streams[i].prefix, tok) == 0 ) {
if ( strcasecmp(streams[i].prefix, tok) == 0 )
{
streams[i].enabled = true; streams[i].enabled = true;
enabled_streams.insert(tok); enabled_streams.insert(tok);
goto next; goto next;
@ -141,40 +123,35 @@ void DebugLogger::EnableStreams(const char* s)
} }
delete[] tmp; delete[] tmp;
} }
bool DebugLogger::CheckStreams(const std::set<std::string>& plugin_names) bool DebugLogger::CheckStreams(const std::set<std::string>& plugin_names) {
{
bool ok = true; bool ok = true;
std::set<std::string> available_plugin_streams; std::set<std::string> available_plugin_streams;
for ( const auto& p : plugin_names ) for ( const auto& p : plugin_names )
available_plugin_streams.insert(PluginStreamName(p)); available_plugin_streams.insert(PluginStreamName(p));
for ( const auto& stream : enabled_streams ) for ( const auto& stream : enabled_streams ) {
{
if ( ! util::starts_with(stream, "plugin-") ) if ( ! util::starts_with(stream, "plugin-") )
continue; continue;
if ( available_plugin_streams.count(stream) == 0 ) if ( available_plugin_streams.count(stream) == 0 ) {
{
reporter->Error("No plugin debug stream '%s' found", stream.c_str()); reporter->Error("No plugin debug stream '%s' found", stream.c_str());
ok = false; ok = false;
} }
} }
return ok; return ok;
} }
void DebugLogger::Log(DebugStream stream, const char* fmt, ...) void DebugLogger::Log(DebugStream stream, const char* fmt, ...) {
{
Stream* g = &streams[int(stream)]; Stream* g = &streams[int(stream)];
if ( ! g->enabled ) if ( ! g->enabled )
return; return;
fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), g->prefix);
g->prefix);
for ( int i = g->indent; i > 0; --i ) for ( int i = g->indent; i > 0; --i )
fputs(" ", file); fputs(" ", file);
@ -186,10 +163,9 @@ void DebugLogger::Log(DebugStream stream, const char* fmt, ...)
fputc('\n', file); fputc('\n', file);
fflush(file); fflush(file);
} }
void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) {
{
std::string tok = PluginStreamName(plugin.Name()); std::string tok = PluginStreamName(plugin.Name());
if ( enabled_streams.find(tok) == enabled_streams.end() ) if ( enabled_streams.find(tok) == enabled_streams.end() )
@ -205,8 +181,8 @@ void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...)
fputc('\n', file); fputc('\n', file);
fflush(file); fflush(file);
} }
} // namespace zeek::detail } // namespace zeek::detail
#endif #endif

View file

@ -17,27 +17,23 @@
if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \ if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__)
#define DBG_LOG_VERBOSE(stream, ...) \ #define DBG_LOG_VERBOSE(stream, ...) \
if ( ::zeek::detail::debug_logger.IsVerbose() && \ if ( ::zeek::detail::debug_logger.IsVerbose() && ::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__)
#define DBG_PUSH(stream) ::zeek::detail::debug_logger.PushIndent(stream) #define DBG_PUSH(stream) ::zeek::detail::debug_logger.PushIndent(stream)
#define DBG_POP(stream) ::zeek::detail::debug_logger.PopIndent(stream) #define DBG_POP(stream) ::zeek::detail::debug_logger.PopIndent(stream)
#define PLUGIN_DBG_LOG(plugin, ...) ::zeek::detail::debug_logger.Log(plugin, __VA_ARGS__) #define PLUGIN_DBG_LOG(plugin, ...) ::zeek::detail::debug_logger.Log(plugin, __VA_ARGS__)
namespace zeek namespace zeek {
{
namespace plugin namespace plugin {
{
class Plugin; class Plugin;
} }
// To add a new debugging stream, add a constant here as well as // To add a new debugging stream, add a constant here as well as
// an entry to DebugLogger::streams in DebugLogger.cc. // an entry to DebugLogger::streams in DebugLogger.cc.
enum DebugStream enum DebugStream {
{
DBG_SERIAL, // Serialization DBG_SERIAL, // Serialization
DBG_RULES, // Signature matching DBG_RULES, // Signature matching
DBG_STRING, // String code DBG_STRING, // String code
@ -60,13 +56,11 @@ enum DebugStream
DBG_SPICY, // Spicy functionality DBG_SPICY, // Spicy functionality
NUM_DBGS // Has to be last NUM_DBGS // Has to be last
}; };
namespace detail namespace detail {
{
class DebugLogger class DebugLogger {
{
public: public:
// Output goes to stderr per default. // Output goes to stderr per default.
DebugLogger(); DebugLogger();
@ -75,8 +69,7 @@ public:
void OpenDebugLog(const char* filename = 0); void OpenDebugLog(const char* filename = 0);
void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4))); void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4)));
void Log(const plugin::Plugin& plugin, const char* fmt, ...) void Log(const plugin::Plugin& plugin, const char* fmt, ...) __attribute__((format(printf, 3, 4)));
__attribute__((format(printf, 3, 4)));
void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; } void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; }
void PopIndent(DebugStream stream) { --streams[int(stream)].indent; } void PopIndent(DebugStream stream) { --streams[int(stream)].indent; }
@ -101,8 +94,7 @@ private:
FILE* file; FILE* file;
bool verbose; bool verbose;
struct Stream struct Stream {
{
const char* prefix; const char* prefix;
int indent; int indent;
bool enabled; bool enabled;
@ -112,16 +104,15 @@ private:
static Stream streams[NUM_DBGS]; static Stream streams[NUM_DBGS];
const std::string PluginStreamName(const std::string& plugin_name) const std::string PluginStreamName(const std::string& plugin_name) {
{
return "plugin-" + util::strreplace(plugin_name, "::", "-"); return "plugin-" + util::strreplace(plugin_name, "::", "-");
} }
}; };
extern DebugLogger debug_logger; extern DebugLogger debug_logger;
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek
#else #else
#define DBG_LOG(...) #define DBG_LOG(...)

View file

@ -17,24 +17,20 @@
#define DEFAULT_SIZE 128 #define DEFAULT_SIZE 128
#define SLOP 10 #define SLOP 10
namespace zeek namespace zeek {
{
ODesc::ODesc(DescType t, File* arg_f) ODesc::ODesc(DescType t, File* arg_f) {
{
type = t; type = t;
style = STANDARD_STYLE; style = STANDARD_STYLE;
f = arg_f; f = arg_f;
if ( f == nullptr ) if ( f == nullptr ) {
{
size = DEFAULT_SIZE; size = DEFAULT_SIZE;
base = util::safe_malloc(size); base = util::safe_malloc(size);
((char*)base)[0] = '\0'; ((char*)base)[0] = '\0';
offset = 0; offset = 0;
} }
else else {
{
offset = size = 0; offset = size = 0;
base = nullptr; base = nullptr;
} }
@ -48,51 +44,39 @@ ODesc::ODesc(DescType t, File* arg_f)
indent_with_spaces = 0; indent_with_spaces = 0;
escape = false; escape = false;
utf8 = false; utf8 = false;
} }
ODesc::~ODesc() ODesc::~ODesc() {
{ if ( f ) {
if ( f )
{
if ( do_flush ) if ( do_flush )
f->Flush(); f->Flush();
} }
else if ( base ) else if ( base )
free(base); free(base);
} }
void ODesc::EnableEscaping() void ODesc::EnableEscaping() { escape = true; }
{
escape = true;
}
void ODesc::EnableUTF8() void ODesc::EnableUTF8() { utf8 = true; }
{
utf8 = true;
}
void ODesc::PushIndent() void ODesc::PushIndent() {
{
++indent_level; ++indent_level;
NL(); NL();
} }
void ODesc::PopIndent() void ODesc::PopIndent() {
{
if ( --indent_level < 0 ) if ( --indent_level < 0 )
reporter->InternalError("ODesc::PopIndent underflow"); reporter->InternalError("ODesc::PopIndent underflow");
NL(); NL();
} }
void ODesc::PopIndentNoNL() void ODesc::PopIndentNoNL() {
{
if ( --indent_level < 0 ) if ( --indent_level < 0 )
reporter->InternalError("ODesc::PopIndent underflow"); reporter->InternalError("ODesc::PopIndent underflow");
} }
void ODesc::Add(const char* s, int do_indent) void ODesc::Add(const char* s, int do_indent) {
{
unsigned int n = strlen(s); unsigned int n = strlen(s);
if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' ) if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' )
@ -102,62 +86,52 @@ void ODesc::Add(const char* s, int do_indent)
AddBytes(s, n + 1); AddBytes(s, n + 1);
else else
AddBytes(s, n); AddBytes(s, n);
} }
void ODesc::Add(int i) void ODesc::Add(int i) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&i, sizeof(i)); AddBytes(&i, sizeof(i));
else else {
{
char tmp[256]; char tmp[256];
modp_litoa10(i, tmp); modp_litoa10(i, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(uint32_t u) void ODesc::Add(uint32_t u) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&u, sizeof(u)); AddBytes(&u, sizeof(u));
else else {
{
char tmp[256]; char tmp[256];
modp_ulitoa10(u, tmp); modp_ulitoa10(u, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(int64_t i) void ODesc::Add(int64_t i) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&i, sizeof(i)); AddBytes(&i, sizeof(i));
else else {
{
char tmp[256]; char tmp[256];
modp_litoa10(i, tmp); modp_litoa10(i, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(uint64_t u) void ODesc::Add(uint64_t u) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&u, sizeof(u)); AddBytes(&u, sizeof(u));
else else {
{
char tmp[256]; char tmp[256];
modp_ulitoa10(u, tmp); modp_ulitoa10(u, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(double d, bool no_exp) void ODesc::Add(double d, bool no_exp) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&d, sizeof(d)); AddBytes(&d, sizeof(d));
else else {
{
// Buffer needs enough chars to store max. possible "double" value // Buffer needs enough chars to store max. possible "double" value
// of 1.79e308 without using scientific notation. // of 1.79e308 without using scientific notation.
char tmp[350]; char tmp[350];
@ -169,8 +143,7 @@ void ODesc::Add(double d, bool no_exp)
Add(tmp); Add(tmp);
auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool {
{
auto v = a - b; auto v = a - b;
return v < 0 ? -v < tolerance : v < tolerance; return v < 0 ? -v < tolerance : v < tolerance;
}; };
@ -179,80 +152,63 @@ void ODesc::Add(double d, bool no_exp)
// disambiguate from integer // disambiguate from integer
Add(".0"); Add(".0");
} }
} }
void ODesc::Add(const IPAddr& addr) void ODesc::Add(const IPAddr& addr) { Add(addr.AsString()); }
{
Add(addr.AsString());
}
void ODesc::Add(const IPPrefix& prefix) void ODesc::Add(const IPPrefix& prefix) { Add(prefix.AsString()); }
{
Add(prefix.AsString());
}
void ODesc::AddCS(const char* s) void ODesc::AddCS(const char* s) {
{
int n = strlen(s); int n = strlen(s);
Add(n); Add(n);
if ( ! IsBinary() ) if ( ! IsBinary() )
Add(" "); Add(" ");
Add(s); Add(s);
} }
void ODesc::AddBytes(const String* s) void ODesc::AddBytes(const String* s) {
{ if ( IsReadable() ) {
if ( IsReadable() )
{
if ( Style() == RAW_STYLE ) if ( Style() == RAW_STYLE )
AddBytes(reinterpret_cast<const char*>(s->Bytes()), s->Len()); AddBytes(reinterpret_cast<const char*>(s->Bytes()), s->Len());
else else {
{
const char* str = s->Render(String::EXPANDED_STRING); const char* str = s->Render(String::EXPANDED_STRING);
Add(str); Add(str);
delete[] str; delete[] str;
} }
} }
else else {
{
Add(s->Len()); Add(s->Len());
if ( ! IsBinary() ) if ( ! IsBinary() )
Add(" "); Add(" ");
AddBytes(s->Bytes(), s->Len()); AddBytes(s->Bytes(), s->Len());
} }
} }
void ODesc::Indent() void ODesc::Indent() {
{ if ( indent_with_spaces > 0 ) {
if ( indent_with_spaces > 0 )
{
for ( int i = 0; i < indent_level; ++i ) for ( int i = 0; i < indent_level; ++i )
for ( int j = 0; j < indent_with_spaces; ++j ) for ( int j = 0; j < indent_with_spaces; ++j )
Add(" ", 0); Add(" ", 0);
} }
else else {
{
for ( int i = 0; i < indent_level; ++i ) for ( int i = 0; i < indent_level; ++i )
Add("\t", 0); Add("\t", 0);
} }
} }
static bool starts_with(const char* str1, const char* str2, size_t len) static bool starts_with(const char* str1, const char* str2, size_t len) {
{
for ( size_t i = 0; i < len; ++i ) for ( size_t i = 0; i < len; ++i )
if ( str1[i] != str2[i] ) if ( str1[i] != str2[i] )
return false; return false;
return true; return true;
} }
size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) {
{
if ( escape_sequences.empty() ) if ( escape_sequences.empty() )
return 0; return 0;
for ( const auto& esc_str : escape_sequences ) for ( const auto& esc_str : escape_sequences ) {
{
size_t esc_len = esc_str.length(); size_t esc_len = esc_str.length();
if ( start + esc_len > end ) if ( start + esc_len > end )
@ -263,15 +219,13 @@ size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end)
} }
return 0; return 0;
} }
std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n) std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n) {
{
if ( IsBinary() ) if ( IsBinary() )
return {nullptr, 0}; return {nullptr, 0};
for ( size_t i = 0; i < n; ++i ) for ( size_t i = 0; i < n; ++i ) {
{
auto printable = isprint(bytes[i]); auto printable = isprint(bytes[i]);
if ( ! printable && ! utf8 ) if ( ! printable && ! utf8 )
@ -287,12 +241,10 @@ std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n
} }
return {nullptr, 0}; return {nullptr, 0};
} }
void ODesc::AddBytes(const void* bytes, unsigned int n) void ODesc::AddBytes(const void* bytes, unsigned int n) {
{ if ( ! escape ) {
if ( ! escape )
{
AddBytesRaw(bytes, n); AddBytesRaw(bytes, n);
return; return;
} }
@ -300,14 +252,11 @@ void ODesc::AddBytes(const void* bytes, unsigned int n)
const char* s = (const char*)bytes; const char* s = (const char*)bytes;
const char* e = (const char*)bytes + n; const char* e = (const char*)bytes + n;
while ( s < e ) while ( s < e ) {
{
auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s); auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s);
if ( esc_start != nullptr ) if ( esc_start != nullptr ) {
{ if ( utf8 ) {
if ( utf8 )
{
std::string result = util::json_escape_utf8(s, esc_start - s, false); std::string result = util::json_escape_utf8(s, esc_start - s, false);
AddBytesRaw(result.c_str(), result.size()); AddBytesRaw(result.c_str(), result.size());
} }
@ -317,10 +266,8 @@ void ODesc::AddBytes(const void* bytes, unsigned int n)
util::get_escaped_string(this, esc_start, esc_len, true); util::get_escaped_string(this, esc_start, esc_len, true);
s = esc_start + esc_len; s = esc_start + esc_len;
} }
else else {
{ if ( utf8 ) {
if ( utf8 )
{
std::string result = util::json_escape_utf8(s, e - s, false); std::string result = util::json_escape_utf8(s, e - s, false);
AddBytesRaw(result.c_str(), result.size()); AddBytesRaw(result.c_str(), result.size());
} }
@ -330,19 +277,16 @@ void ODesc::AddBytes(const void* bytes, unsigned int n)
break; break;
} }
} }
} }
void ODesc::AddBytesRaw(const void* bytes, unsigned int n) void ODesc::AddBytesRaw(const void* bytes, unsigned int n) {
{
if ( n == 0 ) if ( n == 0 )
return; return;
if ( f ) if ( f ) {
{
static bool write_failed = false; static bool write_failed = false;
if ( ! f->Write((const char*)bytes, n) ) if ( ! f->Write((const char*)bytes, n) ) {
{
if ( ! write_failed ) if ( ! write_failed )
// Most likely it's a "disk full" so report // Most likely it's a "disk full" so report
// subsequent failures only once. // subsequent failures only once.
@ -355,8 +299,7 @@ void ODesc::AddBytesRaw(const void* bytes, unsigned int n)
write_failed = false; write_failed = false;
} }
else else {
{
Grow(n); Grow(n);
// The following casting contortions are necessary because // The following casting contortions are necessary because
@ -367,59 +310,51 @@ void ODesc::AddBytesRaw(const void* bytes, unsigned int n)
((char*)base)[offset] = '\0'; // ensure that always NUL-term. ((char*)base)[offset] = '\0'; // ensure that always NUL-term.
} }
} }
void ODesc::Grow(unsigned int n) void ODesc::Grow(unsigned int n) {
{
bool size_changed = false; bool size_changed = false;
while ( offset + n + SLOP >= size ) while ( offset + n + SLOP >= size ) {
{
size *= 2; size *= 2;
size_changed = true; size_changed = true;
} }
if ( size_changed ) if ( size_changed )
base = util::safe_realloc(base, size); base = util::safe_realloc(base, size);
} }
void ODesc::Clear() void ODesc::Clear() {
{
offset = 0; offset = 0;
// If we've allocated an exceedingly large amount of space, free it. // If we've allocated an exceedingly large amount of space, free it.
if ( size > 10 * 1024 * 1024 ) if ( size > 10 * 1024 * 1024 ) {
{
free(base); free(base);
size = DEFAULT_SIZE; size = DEFAULT_SIZE;
base = util::safe_malloc(size); base = util::safe_malloc(size);
((char*)base)[0] = '\0'; ((char*)base)[0] = '\0';
} }
} }
bool ODesc::PushType(const Type* type) bool ODesc::PushType(const Type* type) {
{
auto res = encountered_types.insert(type); auto res = encountered_types.insert(type);
return std::get<1>(res); return std::get<1>(res);
} }
bool ODesc::PopType(const Type* type) bool ODesc::PopType(const Type* type) {
{
size_t res = encountered_types.erase(type); size_t res = encountered_types.erase(type);
return (res == 1); return (res == 1);
} }
bool ODesc::FindType(const Type* type) bool ODesc::FindType(const Type* type) {
{
auto res = encountered_types.find(type); auto res = encountered_types.find(type);
if ( res != encountered_types.end() ) if ( res != encountered_types.end() )
return true; return true;
return false; return false;
} }
std::string obj_desc(const Obj* o) std::string obj_desc(const Obj* o) {
{
static ODesc d; static ODesc d;
d.Clear(); d.Clear();
@ -428,10 +363,9 @@ std::string obj_desc(const Obj* o)
o->GetLocationInfo()->Describe(&d); o->GetLocationInfo()->Describe(&d);
return d.Description(); return d.Description();
} }
std::string obj_desc_short(const Obj* o) std::string obj_desc_short(const Obj* o) {
{
static ODesc d; static ODesc d;
d.SetShort(true); d.SetShort(true);
@ -439,6 +373,6 @@ std::string obj_desc_short(const Obj* o)
o->Describe(&d); o->Describe(&d);
return d.Description(); return d.Description();
} }
} // namespace zeek } // namespace zeek

View file

@ -10,28 +10,24 @@
#include "zeek/ZeekString.h" // for byte_vec #include "zeek/ZeekString.h" // for byte_vec
#include "zeek/util.h" // for zeek_int_t #include "zeek/util.h" // for zeek_int_t
namespace zeek namespace zeek {
{
class IPAddr; class IPAddr;
class IPPrefix; class IPPrefix;
class File; class File;
class Type; class Type;
enum DescType enum DescType {
{
DESC_READABLE, DESC_READABLE,
DESC_BINARY, DESC_BINARY,
}; };
enum DescStyle enum DescStyle {
{
STANDARD_STYLE, STANDARD_STYLE,
RAW_STYLE, RAW_STYLE,
}; };
class ODesc class ODesc {
{
public: public:
explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr); explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr);
@ -69,10 +65,7 @@ public:
void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); } void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); }
void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); } void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); }
void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); } void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); }
void RemoveEscapeSequence(const char* s, size_t n) void RemoveEscapeSequence(const char* s, size_t n) { escape_sequences.erase(std::string(s, n)); }
{
escape_sequences.erase(std::string(s, n));
}
void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); } void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); }
void PushIndent(); void PushIndent();
@ -100,40 +93,33 @@ public:
void AddBytes(const String* s); void AddBytes(const String* s);
void Add(const char* s1, const char* s2) void Add(const char* s1, const char* s2) {
{
Add(s1); Add(s1);
Add(s2); Add(s2);
} }
void AddSP(const char* s1, const char* s2) void AddSP(const char* s1, const char* s2) {
{
Add(s1); Add(s1);
AddSP(s2); AddSP(s2);
} }
void AddSP(const char* s) void AddSP(const char* s) {
{
Add(s); Add(s);
SP(); SP();
} }
void AddCount(zeek_int_t n) void AddCount(zeek_int_t n) {
{ if ( ! IsReadable() ) {
if ( ! IsReadable() )
{
Add(n); Add(n);
SP(); SP();
} }
} }
void SP() void SP() {
{
if ( ! IsBinary() ) if ( ! IsBinary() )
Add(" ", 0); Add(" ", 0);
} }
void NL() void NL() {
{
if ( ! IsBinary() && ! is_short ) if ( ! IsBinary() && ! is_short )
Add("\n", 0); Add("\n", 0);
} }
@ -146,8 +132,7 @@ public:
const char* Description() const { return (const char*)base; } const char* Description() const { return (const char*)base; }
const u_char* Bytes() const { return (const u_char*)base; } const u_char* Bytes() const { return (const u_char*)base; }
byte_vec TakeBytes() byte_vec TakeBytes() {
{
const void* t = base; const void* t = base;
base = nullptr; base = nullptr;
size = 0; size = 0;
@ -223,7 +208,7 @@ protected:
File* f; // or the file we're using. File* f; // or the file we're using.
std::set<const Type*> encountered_types; std::set<const Type*> encountered_types;
}; };
// Returns a string representation of an object's description. Used for // Returns a string representation of an object's description. Used for
// debugging and error messages. takes a bare pointer rather than an // debugging and error messages. takes a bare pointer rather than an
@ -235,4 +220,4 @@ std::string obj_desc(const Obj* o);
// Same as obj_desc(), but ensure it is short and don't include location info. // Same as obj_desc(), but ensure it is short and don't include location info.
std::string obj_desc_short(const Obj* o); std::string obj_desc_short(const Obj* o);
} // namespace zeek } // namespace zeek

View file

@ -5,15 +5,13 @@
#include "zeek/3rdparty/doctest.h" #include "zeek/3rdparty/doctest.h"
#include "zeek/Hash.h" #include "zeek/Hash.h"
namespace zeek namespace zeek {
{
// namespace detail // namespace detail
TEST_SUITE_BEGIN("Dict"); TEST_SUITE_BEGIN("Dict");
TEST_CASE("dict construction") TEST_CASE("dict construction") {
{
PDict<int> dict; PDict<int> dict;
CHECK(! dict.IsOrdered()); CHECK(! dict.IsOrdered());
CHECK(dict.Length() == 0); CHECK(dict.Length() == 0);
@ -21,10 +19,9 @@ TEST_CASE("dict construction")
PDict<int> dict2(ORDERED); PDict<int> dict2(ORDERED);
CHECK(dict2.IsOrdered()); CHECK(dict2.IsOrdered());
CHECK(dict2.Length() == 0); CHECK(dict2.Length() == 0);
} }
TEST_CASE("dict operation") TEST_CASE("dict operation") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 10; uint32_t val = 10;
@ -67,10 +64,9 @@ TEST_CASE("dict operation")
delete key; delete key;
delete key2; delete key2;
} }
TEST_CASE("dict nthentry") TEST_CASE("dict nthentry") {
{
PDict<uint32_t> unordered(UNORDERED); PDict<uint32_t> unordered(UNORDERED);
PDict<uint32_t> ordered(ORDERED); PDict<uint32_t> ordered(ORDERED);
@ -103,10 +99,9 @@ TEST_CASE("dict nthentry")
delete okey2; delete okey2;
delete ukey; delete ukey;
delete ukey2; delete ukey2;
} }
TEST_CASE("dict iteration") TEST_CASE("dict iteration") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 15; uint32_t val = 15;
@ -122,13 +117,11 @@ TEST_CASE("dict iteration")
int count = 0; int count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint64_t k = *(uint32_t*)entry.GetKey(); uint64_t k = *(uint32_t*)entry.GetKey();
switch ( count ) switch ( count ) {
{
case 0: case 0:
CHECK(k == key_val2); CHECK(k == key_val2);
CHECK(*v == val2); CHECK(*v == val2);
@ -137,8 +130,7 @@ TEST_CASE("dict iteration")
CHECK(k == key_val); CHECK(k == key_val);
CHECK(*v == val); CHECK(*v == val);
break; break;
default: default: break;
break;
} }
count++; count++;
@ -153,10 +145,9 @@ TEST_CASE("dict iteration")
delete key; delete key;
delete key2; delete key2;
} }
TEST_CASE("dict robust iteration") TEST_CASE("dict robust iteration") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 15; uint32_t val = 15;
@ -178,13 +169,11 @@ TEST_CASE("dict robust iteration")
int count = 0; int count = 0;
auto it = dict.begin_robust(); auto it = dict.begin_robust();
for ( ; it != dict.end_robust(); ++it ) for ( ; it != dict.end_robust(); ++it ) {
{
auto* v = it->value; auto* v = it->value;
uint64_t k = *(uint32_t*)it->GetKey(); uint64_t k = *(uint32_t*)it->GetKey();
switch ( count ) switch ( count ) {
{
case 0: case 0:
CHECK(k == key_val2); CHECK(k == key_val2);
CHECK(*v == val2); CHECK(*v == val2);
@ -213,13 +202,11 @@ TEST_CASE("dict robust iteration")
int count = 0; int count = 0;
auto it = dict.begin_robust(); auto it = dict.begin_robust();
for ( ; it != dict.end_robust(); ++it ) for ( ; it != dict.end_robust(); ++it ) {
{
auto* v = it->value; auto* v = it->value;
uint64_t k = *(uint32_t*)it->GetKey(); uint64_t k = *(uint32_t*)it->GetKey();
switch ( count ) switch ( count ) {
{
case 0: case 0:
CHECK(k == key_val2); CHECK(k == key_val2);
CHECK(*v == val2); CHECK(*v == val2);
@ -244,10 +231,9 @@ TEST_CASE("dict robust iteration")
delete key; delete key;
delete key2; delete key2;
delete key3; delete key3;
} }
TEST_CASE("dict ordered iteration") TEST_CASE("dict ordered iteration") {
{
PDict<uint32_t> dict(DictOrder::ORDERED); PDict<uint32_t> dict(DictOrder::ORDERED);
// These key values are specifically contrived to be inserted // These key values are specifically contrived to be inserted
@ -276,8 +262,7 @@ TEST_CASE("dict ordered iteration")
int count = 0; int count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey(); uint32_t k = *(uint32_t*)entry.GetKey();
@ -296,8 +281,7 @@ TEST_CASE("dict ordered iteration")
dict.Insert(key4.get(), &val4); dict.Insert(key4.get(), &val4);
count = 0; count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey(); uint32_t k = *(uint32_t*)entry.GetKey();
@ -318,8 +302,7 @@ TEST_CASE("dict ordered iteration")
dict.Remove(key2.get()); dict.Remove(key2.get());
count = 0; count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey(); uint32_t k = *(uint32_t*)entry.GetKey();
@ -334,18 +317,16 @@ TEST_CASE("dict ordered iteration")
count++; count++;
} }
} }
class DictTestDummy class DictTestDummy {
{
public: public:
DictTestDummy(int v) : v(v) { } DictTestDummy(int v) : v(v) {}
~DictTestDummy() = default; ~DictTestDummy() = default;
int v = 0; int v = 0;
}; };
TEST_CASE("dict robust iteration replacement") TEST_CASE("dict robust iteration replacement") {
{
PDict<DictTestDummy> dict; PDict<DictTestDummy> dict;
DictTestDummy* val1 = new DictTestDummy(15); DictTestDummy* val1 = new DictTestDummy(15);
@ -369,7 +350,8 @@ TEST_CASE("dict robust iteration replacement")
// Iterate past the first couple of elements so we're not done, but the // Iterate past the first couple of elements so we're not done, but the
// iterator is still pointing at a valid element. // iterator is still pointing at a valid element.
for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) { } for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) {
}
// Store off the value at this iterator index // Store off the value at this iterator index
auto* v = it->value; auto* v = it->value;
@ -383,8 +365,7 @@ TEST_CASE("dict robust iteration replacement")
delete val2; delete val2;
// This shouldn't crash with AddressSanitizer // This shouldn't crash with AddressSanitizer
for ( ; it != dict.end_robust(); ++it ) for ( ; it != dict.end_robust(); ++it ) {
{
uint64_t k = *(uint32_t*)it->GetKey(); uint64_t k = *(uint32_t*)it->GetKey();
auto* v = it->value; auto* v = it->value;
CHECK(v->v == 50); CHECK(v->v == 50);
@ -397,10 +378,9 @@ TEST_CASE("dict robust iteration replacement")
delete val1; delete val1;
delete val3; delete val3;
delete val4; delete val4;
} }
TEST_CASE("dict iterator invalidation") TEST_CASE("dict iterator invalidation") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 15; uint32_t val = 15;
@ -451,12 +431,9 @@ TEST_CASE("dict iterator invalidation")
delete key; delete key;
delete key2; delete key2;
delete key3; delete key3;
} }
// private // private
void generic_delete_func(void* v) void generic_delete_func(void* v) { free(v); }
{
free(v);
}
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -14,39 +14,30 @@
#include "zeek/Var.h" #include "zeek/Var.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
Discarder::Discarder() Discarder::Discarder() {
{
check_ip = id::find_func("discarder_check_ip"); check_ip = id::find_func("discarder_check_ip");
check_tcp = id::find_func("discarder_check_tcp"); check_tcp = id::find_func("discarder_check_tcp");
check_udp = id::find_func("discarder_check_udp"); check_udp = id::find_func("discarder_check_udp");
check_icmp = id::find_func("discarder_check_icmp"); check_icmp = id::find_func("discarder_check_icmp");
discarder_maxlen = static_cast<int>(id::find_val("discarder_maxlen")->AsCount()); discarder_maxlen = static_cast<int>(id::find_val("discarder_maxlen")->AsCount());
} }
bool Discarder::IsActive() bool Discarder::IsActive() { return check_ip || check_tcp || check_udp || check_icmp; }
{
return check_ip || check_tcp || check_udp || check_icmp;
}
bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) {
{
bool discard_packet = false; bool discard_packet = false;
if ( check_ip ) if ( check_ip ) {
{
zeek::Args args{ip->ToPktHdrVal()}; zeek::Args args{ip->ToPktHdrVal()};
try try {
{
discard_packet = check_ip->Invoke(&args)->AsBool(); discard_packet = check_ip->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
@ -70,8 +61,7 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
bool is_tcp = (proto == IPPROTO_TCP); bool is_tcp = (proto == IPPROTO_TCP);
bool is_udp = (proto == IPPROTO_UDP); bool is_udp = (proto == IPPROTO_UDP);
int min_hdr_len = is_tcp ? sizeof(struct tcphdr) int min_hdr_len = is_tcp ? sizeof(struct tcphdr) : (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp));
: (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp));
if ( len < min_hdr_len || caplen < min_hdr_len ) if ( len < min_hdr_len || caplen < min_hdr_len )
// we don't have a complete protocol header // we don't have a complete protocol header
@ -81,10 +71,8 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
// this gets advanced past the transport header. // this gets advanced past the transport header.
const u_char* data = ip->Payload(); const u_char* data = ip->Payload();
if ( is_tcp ) if ( is_tcp ) {
{ if ( check_tcp ) {
if ( check_tcp )
{
const struct tcphdr* tp = (const struct tcphdr*)data; const struct tcphdr* tp = (const struct tcphdr*)data;
int th_len = tp->th_off * 4; int th_len = tp->th_off * 4;
@ -93,22 +81,18 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
{AdoptRef{}, BuildData(data, th_len, len, caplen)}, {AdoptRef{}, BuildData(data, th_len, len, caplen)},
}; };
try try {
{
discard_packet = check_tcp->Invoke(&args)->AsBool(); discard_packet = check_tcp->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
} }
} }
else if ( is_udp ) else if ( is_udp ) {
{ if ( check_udp ) {
if ( check_udp )
{
const struct udphdr* up = (const struct udphdr*)data; const struct udphdr* up = (const struct udphdr*)data;
int uh_len = sizeof(struct udphdr); int uh_len = sizeof(struct udphdr);
@ -117,43 +101,36 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
{AdoptRef{}, BuildData(data, uh_len, len, caplen)}, {AdoptRef{}, BuildData(data, uh_len, len, caplen)},
}; };
try try {
{
discard_packet = check_udp->Invoke(&args)->AsBool(); discard_packet = check_udp->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
} }
} }
else else {
{ if ( check_icmp ) {
if ( check_icmp )
{
const struct icmp* ih = (const struct icmp*)data; const struct icmp* ih = (const struct icmp*)data;
zeek::Args args{ip->ToPktHdrVal()}; zeek::Args args{ip->ToPktHdrVal()};
try try {
{
discard_packet = check_icmp->Invoke(&args)->AsBool(); discard_packet = check_icmp->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
} }
} }
return discard_packet; return discard_packet;
} }
Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) {
{
len -= hdrlen; len -= hdrlen;
caplen -= hdrlen; caplen -= hdrlen;
data += hdrlen; data += hdrlen;
@ -161,6 +138,6 @@ Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen)
len = std::max(std::min(std::min(len, caplen), discarder_maxlen), 0); len = std::max(std::min(std::min(len, caplen), discarder_maxlen), 0);
return new StringVal(new String(data, len, true)); return new StringVal(new String(data, len, true));
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,19 +7,16 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
class Val; class Val;
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
namespace detail namespace detail {
{
class Discarder final class Discarder final {
{
public: public:
Discarder(); Discarder();
~Discarder() = default; ~Discarder() = default;
@ -38,7 +35,7 @@ protected:
// Maximum amount of application data passed to filtering functions. // Maximum amount of application data passed to filtering functions.
int discarder_maxlen; int discarder_maxlen;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -7,11 +7,9 @@
#include "zeek/CCL.h" #include "zeek/CCL.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
EquivClass::EquivClass(int arg_size) EquivClass::EquivClass(int arg_size) {
{
size = arg_size; size = arg_size;
fwd = new int[size]; fwd = new int[size];
bck = new int[size]; bck = new int[size];
@ -25,10 +23,8 @@ EquivClass::EquivClass(int arg_size)
bck[0] = ec_nil; bck[0] = ec_nil;
fwd[size - 1] = ec_nil; fwd[size - 1] = ec_nil;
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{ if ( i > 0 ) {
if ( i > 0 )
{
fwd[i - 1] = i; fwd[i - 1] = i;
bck[i] = i - 1; bck[i] = i - 1;
} }
@ -36,19 +32,17 @@ EquivClass::EquivClass(int arg_size)
equiv_class[i] = no_class; equiv_class[i] = no_class;
rep[i] = no_rep; rep[i] = no_rep;
} }
} }
EquivClass::~EquivClass() EquivClass::~EquivClass() {
{
delete[] fwd; delete[] fwd;
delete[] bck; delete[] bck;
delete[] equiv_class; delete[] equiv_class;
delete[] rep; delete[] rep;
delete[] ccl_flags; delete[] ccl_flags;
} }
void EquivClass::ConvertCCL(CCL* ccl) void EquivClass::ConvertCCL(CCL* ccl) {
{
// For each character in the class, add the character's // For each character in the class, add the character's
// equivalence class to the new "character" class we are // equivalence class to the new "character" class we are
// creating. Thus when we are all done, the character class // creating. Thus when we are all done, the character class
@ -58,50 +52,43 @@ void EquivClass::ConvertCCL(CCL* ccl)
int_list* c_syms = ccl->Syms(); int_list* c_syms = ccl->Syms();
int_list* new_syms = new int_list; int_list* new_syms = new int_list;
for ( auto sym : *c_syms ) for ( auto sym : *c_syms ) {
{
if ( IsRep(sym) ) if ( IsRep(sym) )
new_syms->push_back(SymEquivClass(sym)); new_syms->push_back(SymEquivClass(sym));
} }
ccl->ReplaceSyms(new_syms); ccl->ReplaceSyms(new_syms);
} }
int EquivClass::BuildECs() int EquivClass::BuildECs() {
{
// Create equivalence class numbers. If bck[x] is nil, // Create equivalence class numbers. If bck[x] is nil,
// then x is the representative of its equivalence class. // then x is the representative of its equivalence class.
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
if ( bck[i] == ec_nil ) if ( bck[i] == ec_nil ) {
{
equiv_class[i] = num_ecs++; equiv_class[i] = num_ecs++;
rep[i] = i; rep[i] = i;
for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) {
{
equiv_class[j] = equiv_class[i]; equiv_class[j] = equiv_class[i];
rep[j] = i; rep[j] = i;
} }
} }
return num_ecs; return num_ecs;
} }
void EquivClass::CCL_Use(CCL* ccl) void EquivClass::CCL_Use(CCL* ccl) {
{
// Note that it doesn't matter whether or not the character class is // Note that it doesn't matter whether or not the character class is
// negated. The same results will be obtained in either case. // negated. The same results will be obtained in either case.
if ( ! ccl_flags ) if ( ! ccl_flags ) {
{
ccl_flags = new int[size]; ccl_flags = new int[size];
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
ccl_flags[i] = 0; ccl_flags[i] = 0;
} }
int_list* csyms = ccl->Syms(); int_list* csyms = ccl->Syms();
for ( size_t i = 0; i < csyms->size(); /* no increment */ ) for ( size_t i = 0; i < csyms->size(); /* no increment */ ) {
{
int sym = (*csyms)[i]; int sym = (*csyms)[i];
int old_ec = bck[sym]; int old_ec = bck[sym];
@ -109,17 +96,14 @@ void EquivClass::CCL_Use(CCL* ccl)
size_t j = i + 1; size_t j = i + 1;
for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) { // look for the symbol in the character class
{ // look for the symbol in the character class for ( ; j < csyms->size(); ++j ) {
for ( ; j < csyms->size(); ++j )
{
if ( (*csyms)[j] > k ) if ( (*csyms)[j] > k )
// Since the character class is sorted, // Since the character class is sorted,
// we can stop. // we can stop.
break; break;
if ( (*csyms)[j] == k && ! ccl_flags[j] ) if ( (*csyms)[j] == k && ! ccl_flags[j] ) {
{
// We found an old companion of sym // We found an old companion of sym
// in the ccl. Link it into the new // in the ccl. Link it into the new
// equivalence class and flag it as // equivalence class and flag it as
@ -150,8 +134,7 @@ void EquivClass::CCL_Use(CCL* ccl)
old_ec = k; old_ec = k;
} }
if ( bck[sym] != ec_nil || old_ec != bck[sym] ) if ( bck[sym] != ec_nil || old_ec != bck[sym] ) {
{
bck[sym] = ec_nil; bck[sym] = ec_nil;
fwd[old_ec] = ec_nil; fwd[old_ec] = ec_nil;
} }
@ -163,10 +146,9 @@ void EquivClass::CCL_Use(CCL* ccl)
// Reset "doesn't need processing" flag. // Reset "doesn't need processing" flag.
ccl_flags[i] = 0; ccl_flags[i] = 0;
} }
} }
void EquivClass::UniqueChar(int sym) void EquivClass::UniqueChar(int sym) {
{
// If until now the character has been a proper subset of // If until now the character has been a proper subset of
// an equivalence class, break it away to create a new ec. // an equivalence class, break it away to create a new ec.
@ -178,19 +160,15 @@ void EquivClass::UniqueChar(int sym)
fwd[sym] = ec_nil; fwd[sym] = ec_nil;
bck[sym] = ec_nil; bck[sym] = ec_nil;
} }
void EquivClass::Dump(FILE* f) void EquivClass::Dump(FILE* f) {
{
fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs); fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs);
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
if ( SymEquivClass(i) != 0 ) // skip usually huge default ec if ( SymEquivClass(i) != 0 ) // skip usually huge default ec
fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i)); fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i));
} }
int EquivClass::Size() const int EquivClass::Size() const { return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4)); }
{
return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4));
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,13 +4,11 @@
#include <stdio.h> #include <stdio.h>
namespace zeek::detail namespace zeek::detail {
{
class CCL; class CCL;
class EquivClass class EquivClass {
{
public: public:
explicit EquivClass(int size); explicit EquivClass(int size);
~EquivClass(); ~EquivClass();
@ -44,6 +42,6 @@ protected:
int* rep; // representative for symbol's equivalence class int* rep; // representative for symbol's equivalence class
int* ccl_flags; int* ccl_flags;
int ec_nil, no_class, no_rep; int ec_nil, no_class, no_rep;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -15,20 +15,22 @@
zeek::EventMgr zeek::event_mgr; zeek::EventMgr zeek::event_mgr;
namespace zeek namespace zeek {
{
Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, util::detail::SourceID arg_src,
util::detail::SourceID arg_src, analyzer::ID arg_aid, Obj* arg_obj, double arg_ts) analyzer::ID arg_aid, Obj* arg_obj, double arg_ts)
: handler(arg_handler), args(std::move(arg_args)), src(arg_src), aid(arg_aid), ts(arg_ts), : handler(arg_handler),
obj(arg_obj), next_event(nullptr) args(std::move(arg_args)),
{ src(arg_src),
aid(arg_aid),
ts(arg_ts),
obj(arg_obj),
next_event(nullptr) {
if ( obj ) if ( obj )
Ref(obj); Ref(obj);
} }
void Event::Describe(ODesc* d) const void Event::Describe(ODesc* d) const {
{
if ( d->IsReadable() ) if ( d->IsReadable() )
d->AddSP("event"); d->AddSP("event");
@ -40,23 +42,20 @@ void Event::Describe(ODesc* d) const
describe_vals(args, d); describe_vals(args, d);
if ( ! d->IsBinary() ) if ( ! d->IsBinary() )
d->Add("("); d->Add("(");
} }
void Event::Dispatch(bool no_remote) void Event::Dispatch(bool no_remote) {
{
if ( src == util::detail::SOURCE_BROKER ) if ( src == util::detail::SOURCE_BROKER )
no_remote = true; no_remote = true;
if ( handler->ErrorHandler() ) if ( handler->ErrorHandler() )
reporter->BeginErrorHandler(); reporter->BeginErrorHandler();
try try {
{
handler->Call(&args, no_remote, ts); handler->Call(&args, no_remote, ts);
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
// Already reported. // Already reported.
} }
@ -66,68 +65,59 @@ void Event::Dispatch(bool no_remote)
if ( handler->ErrorHandler() ) if ( handler->ErrorHandler() )
reporter->EndErrorHandler(); reporter->EndErrorHandler();
} }
EventMgr::EventMgr() EventMgr::EventMgr() {
{
head = tail = nullptr; head = tail = nullptr;
current_src = util::detail::SOURCE_LOCAL; current_src = util::detail::SOURCE_LOCAL;
current_aid = 0; current_aid = 0;
current_ts = 0; current_ts = 0;
src_val = nullptr; src_val = nullptr;
draining = false; draining = false;
} }
EventMgr::~EventMgr() EventMgr::~EventMgr() {
{ while ( head ) {
while ( head )
{
Event* n = head->NextEvent(); Event* n = head->NextEvent();
Unref(head); Unref(head);
head = n; head = n;
} }
Unref(src_val); Unref(src_val);
} }
void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, analyzer::ID aid, Obj* obj,
analyzer::ID aid, Obj* obj, double ts) double ts) {
{
QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts)); QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts));
} }
void EventMgr::QueueEvent(Event* event) void EventMgr::QueueEvent(Event* event) {
{
bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false); bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false);
if ( done ) if ( done )
return; return;
if ( ! head ) if ( ! head ) {
{
head = tail = event; head = tail = event;
queue_flare.Fire(); queue_flare.Fire();
} }
else else {
{
tail->SetNext(event); tail->SetNext(event);
tail = event; tail = event;
} }
++event_mgr.num_events_queued; ++event_mgr.num_events_queued;
} }
void EventMgr::Dispatch(Event* event, bool no_remote) void EventMgr::Dispatch(Event* event, bool no_remote) {
{
current_src = event->Source(); current_src = event->Source();
current_aid = event->Analyzer(); current_aid = event->Analyzer();
current_ts = event->Time(); current_ts = event->Time();
event->Dispatch(no_remote); event->Dispatch(no_remote);
Unref(event); Unref(event);
} }
void EventMgr::Drain() void EventMgr::Drain() {
{
if ( event_queue_flush_point ) if ( event_queue_flush_point )
Enqueue(event_queue_flush_point, Args{}); Enqueue(event_queue_flush_point, Args{});
@ -144,14 +134,12 @@ void EventMgr::Drain()
// just one round to make it less likely to break existing scripts // just one round to make it less likely to break existing scripts
// that expect the old behavior to trigger something quickly. // that expect the old behavior to trigger something quickly.
for ( int round = 0; head && round < 2; round++ ) for ( int round = 0; head && round < 2; round++ ) {
{
Event* current = head; Event* current = head;
head = nullptr; head = nullptr;
tail = nullptr; tail = nullptr;
while ( current ) while ( current ) {
{
Event* next = current->NextEvent(); Event* next = current->NextEvent();
current_src = current->Source(); current_src = current->Source();
@ -172,10 +160,9 @@ void EventMgr::Drain()
// Make sure all of the triggers get processed every time the events // Make sure all of the triggers get processed every time the events
// drain. // drain.
detail::trigger_mgr->Process(); detail::trigger_mgr->Process();
} }
void EventMgr::Describe(ODesc* d) const void EventMgr::Describe(ODesc* d) const {
{
int n = 0; int n = 0;
Event* e; Event* e;
for ( e = head; e; e = e->NextEvent() ) for ( e = head; e; e = e->NextEvent() )
@ -183,36 +170,32 @@ void EventMgr::Describe(ODesc* d) const
d->AddCount(n); d->AddCount(n);
for ( e = head; e; e = e->NextEvent() ) for ( e = head; e; e = e->NextEvent() ) {
{
e->Describe(d); e->Describe(d);
d->NL(); d->NL();
} }
} }
void EventMgr::Process() void EventMgr::Process() {
{
queue_flare.Extinguish(); queue_flare.Extinguish();
// While it semes like the most logical thing to do, we dont want // While it semes like the most logical thing to do, we dont want
// to call Drain() as part of this method. It will get called at // to call Drain() as part of this method. It will get called at
// the end of net_run after all of the sources have been processed // the end of net_run after all of the sources have been processed
// and had the opportunity to spawn new events. // and had the opportunity to spawn new events.
} }
void EventMgr::InitPostScript() void EventMgr::InitPostScript() {
{
iosource_mgr->Register(this, true, false); iosource_mgr->Register(this, true, false);
if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) ) if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) )
reporter->FatalError("Failed to register event manager FD with iosource_mgr"); reporter->FatalError("Failed to register event manager FD with iosource_mgr");
} }
void EventMgr::InitPostFork() void EventMgr::InitPostFork() {
{
// Re-initialize the flare, closing and re-opening the underlying // Re-initialize the flare, closing and re-opening the underlying
// pipe FDs. This is needed so that each Zeek process in a supervisor // pipe FDs. This is needed so that each Zeek process in a supervisor
// setup has its own pipe instead of them all sharing a single pipe. // setup has its own pipe instead of them all sharing a single pipe.
queue_flare = zeek::detail::Flare{}; queue_flare = zeek::detail::Flare{};
} }
} // namespace zeek } // namespace zeek

View file

@ -12,22 +12,18 @@
#include "zeek/analyzer/Analyzer.h" #include "zeek/analyzer/Analyzer.h"
#include "zeek/iosource/IOSource.h" #include "zeek/iosource/IOSource.h"
namespace zeek namespace zeek {
{
namespace run_state namespace run_state {
{
extern double network_time; extern double network_time;
} // namespace run_state } // namespace run_state
class EventMgr; class EventMgr;
class Event final : public Obj class Event final : public Obj {
{
public: public:
Event(const EventHandlerPtr& handler, zeek::Args args, Event(const EventHandlerPtr& handler, zeek::Args args, util::detail::SourceID src = util::detail::SOURCE_LOCAL,
util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time);
Obj* obj = nullptr, double ts = run_state::network_time);
void SetNext(Event* n) { next_event = n; } void SetNext(Event* n) { next_event = n; }
Event* NextEvent() const { return next_event; } Event* NextEvent() const { return next_event; }
@ -54,10 +50,9 @@ protected:
double ts; double ts;
Obj* obj; Obj* obj;
Event* next_event; Event* next_event;
}; };
class EventMgr final : public Obj, public iosource::IOSource class EventMgr final : public Obj, public iosource::IOSource {
{
public: public:
EventMgr(); EventMgr();
~EventMgr() override; ~EventMgr() override;
@ -76,17 +71,15 @@ public:
* @param ts timestamp at which the event is intended to be executed * @param ts timestamp at which the event is intended to be executed
* (defaults to current network time). * (defaults to current network time).
*/ */
void Enqueue(const EventHandlerPtr& h, zeek::Args vl, void Enqueue(const EventHandlerPtr& h, zeek::Args vl, util::detail::SourceID src = util::detail::SOURCE_LOCAL,
util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time);
Obj* obj = nullptr, double ts = run_state::network_time);
/** /**
* A version of Enqueue() taking a variable number of arguments. * A version of Enqueue() taking a variable number of arguments.
*/ */
template <class... Args> template<class... Args>
std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>> std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>> Enqueue(
Enqueue(const EventHandlerPtr& h, Args&&... args) const EventHandlerPtr& h, Args&&... args) {
{
return Enqueue(h, zeek::Args{std::forward<Args>(args)...}); return Enqueue(h, zeek::Args{std::forward<Args>(args)...});
} }
@ -135,8 +128,8 @@ protected:
RecordVal* src_val; RecordVal* src_val;
bool draining; bool draining;
detail::Flare queue_flare; detail::Flare queue_flare;
}; };
extern EventMgr event_mgr; extern EventMgr event_mgr;
} // namespace zeek } // namespace zeek

View file

@ -11,31 +11,25 @@
#include "zeek/broker/Manager.h" #include "zeek/broker/Manager.h"
#include "zeek/telemetry/Manager.h" #include "zeek/telemetry/Manager.h"
namespace zeek namespace zeek {
{
EventHandler::EventHandler(std::string arg_name) EventHandler::EventHandler(std::string arg_name) {
{
name = std::move(arg_name); name = std::move(arg_name);
used = false; used = false;
error_handler = false; error_handler = false;
enabled = true; enabled = true;
generate_always = false; generate_always = false;
} }
EventHandler::operator bool() const EventHandler::operator bool() const {
{ return enabled && ((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty());
return enabled && }
((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty());
}
const FuncTypePtr& EventHandler::GetType(bool check_export) const FuncTypePtr& EventHandler::GetType(bool check_export) {
{
if ( type ) if ( type )
return type; return type;
const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, check_export);
check_export);
if ( ! id ) if ( ! id )
return FuncType::nil; return FuncType::nil;
@ -45,19 +39,14 @@ const FuncTypePtr& EventHandler::GetType(bool check_export)
type = id->GetType<FuncType>(); type = id->GetType<FuncType>();
return type; return type;
} }
void EventHandler::SetFunc(FuncPtr f) void EventHandler::SetFunc(FuncPtr f) { local = std::move(f); }
{
local = std::move(f);
}
void EventHandler::Call(Args* vl, bool no_remote, double ts) void EventHandler::Call(Args* vl, bool no_remote, double ts) {
{ if ( ! call_count ) {
if ( ! call_count ) static auto eh_invocations_family =
{ telemetry_mgr->CounterFamily("zeek", "event-handler-invocations", {"name"},
static auto eh_invocations_family = telemetry_mgr->CounterFamily(
"zeek", "event-handler-invocations", {"name"},
"Number of times the given event handler was called", "1", true); "Number of times the given event handler was called", "1", true);
call_count = eh_invocations_family.GetOrAdd({{"name", name}}); call_count = eh_invocations_family.GetOrAdd({{"name", name}});
@ -68,23 +57,19 @@ void EventHandler::Call(Args* vl, bool no_remote, double ts)
if ( new_event ) if ( new_event )
NewEvent(vl); NewEvent(vl);
if ( ! no_remote ) if ( ! no_remote ) {
{ if ( ! auto_publish.empty() ) {
if ( ! auto_publish.empty() )
{
// Send event in form [name, xs...] where xs represent the arguments. // Send event in form [name, xs...] where xs represent the arguments.
broker::vector xs; broker::vector xs;
xs.reserve(vl->size()); xs.reserve(vl->size());
bool valid_args = true; bool valid_args = true;
for ( auto i = 0u; i < vl->size(); ++i ) for ( auto i = 0u; i < vl->size(); ++i ) {
{
auto opt_data = Broker::detail::val_to_data((*vl)[i].get()); auto opt_data = Broker::detail::val_to_data((*vl)[i].get());
if ( opt_data ) if ( opt_data )
xs.emplace_back(std::move(*opt_data)); xs.emplace_back(std::move(*opt_data));
else else {
{
valid_args = false; valid_args = false;
auto_publish.clear(); auto_publish.clear();
reporter->Error("failed auto-remote event '%s', disabled", Name()); reporter->Error("failed auto-remote event '%s', disabled", Name());
@ -92,17 +77,14 @@ void EventHandler::Call(Args* vl, bool no_remote, double ts)
} }
} }
if ( valid_args ) if ( valid_args ) {
{ for ( auto it = auto_publish.begin();; ) {
for ( auto it = auto_publish.begin();; )
{
const auto& topic = *it; const auto& topic = *it;
++it; ++it;
if ( it != auto_publish.end() ) if ( it != auto_publish.end() )
broker_mgr->PublishEvent(topic, Name(), xs, ts); broker_mgr->PublishEvent(topic, Name(), xs, ts);
else else {
{
broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts); broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts);
break; break;
} }
@ -114,10 +96,9 @@ void EventHandler::Call(Args* vl, bool no_remote, double ts)
if ( local ) if ( local )
// No try/catch here; we pass exceptions upstream. // No try/catch here; we pass exceptions upstream.
local->Invoke(vl); local->Invoke(vl);
} }
void EventHandler::NewEvent(Args* vl) void EventHandler::NewEvent(Args* vl) {
{
if ( ! new_event ) if ( ! new_event )
return; return;
@ -132,6 +113,6 @@ void EventHandler::NewEvent(Args* vl)
std::move(vargs), std::move(vargs),
}); });
event_mgr.Dispatch(ev); event_mgr.Dispatch(ev);
} }
} // namespace zeek } // namespace zeek

View file

@ -11,19 +11,16 @@
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
#include "zeek/telemetry/Counter.h" #include "zeek/telemetry/Counter.h"
namespace zeek namespace zeek {
{
namespace run_state namespace run_state {
{
extern double network_time; extern double network_time;
} // namespace run_state } // namespace run_state
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
class EventHandler class EventHandler {
{
public: public:
explicit EventHandler(std::string name); explicit EventHandler(std::string name);
@ -56,10 +53,7 @@ public:
// Flags the event as interesting even if there is no body defined. In // Flags the event as interesting even if there is no body defined. In
// particular, this will then still pass the event on to plugins. // particular, this will then still pass the event on to plugins.
void SetGenerateAlways(bool arg_generate_always = true) void SetGenerateAlways(bool arg_generate_always = true) { generate_always = arg_generate_always; }
{
generate_always = arg_generate_always;
}
bool GenerateAlways() const { return generate_always; } bool GenerateAlways() const { return generate_always; }
uint64_t CallCount() const { return call_count ? call_count->Value() : 0; } uint64_t CallCount() const { return call_count ? call_count->Value() : 0; }
@ -79,22 +73,19 @@ private:
std::optional<zeek::telemetry::IntCounter> call_count; std::optional<zeek::telemetry::IntCounter> call_count;
std::unordered_set<std::string> auto_publish; std::unordered_set<std::string> auto_publish;
}; };
// Encapsulates a ptr to an event handler to overload the boolean operator. // Encapsulates a ptr to an event handler to overload the boolean operator.
class EventHandlerPtr class EventHandlerPtr {
{
public: public:
EventHandlerPtr(EventHandler* p = nullptr) { handler = p; } EventHandlerPtr(EventHandler* p = nullptr) { handler = p; }
EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; } EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; }
const EventHandlerPtr& operator=(EventHandler* p) const EventHandlerPtr& operator=(EventHandler* p) {
{
handler = p; handler = p;
return *this; return *this;
} }
const EventHandlerPtr& operator=(const EventHandlerPtr& h) const EventHandlerPtr& operator=(const EventHandlerPtr& h) {
{
handler = h.handler; handler = h.handler;
return *this; return *this;
} }
@ -109,6 +100,6 @@ public:
private: private:
EventHandler* handler; EventHandler* handler;
}; };
} // namespace zeek } // namespace zeek

View file

@ -7,20 +7,17 @@
#include "zeek/RE.h" #include "zeek/RE.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek namespace zeek {
{
EventRegistry::EventRegistry() = default; EventRegistry::EventRegistry() = default;
EventRegistry::~EventRegistry() noexcept = default; EventRegistry::~EventRegistry() noexcept = default;
EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) {
{
// If there already is an entry in the registry, we have a // If there already is an entry in the registry, we have a
// local handler on the script layer. // local handler on the script layer.
EventHandler* h = event_registry->Lookup(name); EventHandler* h = event_registry->Lookup(name);
if ( h ) if ( h ) {
{
if ( ! is_from_script ) if ( ! is_from_script )
not_only_from_script.insert(std::string(name)); not_only_from_script.insert(std::string(name));
@ -34,140 +31,120 @@ EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_scri
h->SetUsed(); h->SetUsed();
return h; return h;
} }
void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) {
{
std::string name = handler->Name(); std::string name = handler->Name();
handlers[name] = std::unique_ptr<EventHandler>(handler.Ptr()); handlers[name] = std::unique_ptr<EventHandler>(handler.Ptr());
if ( ! is_from_script ) if ( ! is_from_script )
not_only_from_script.insert(name); not_only_from_script.insert(name);
} }
EventHandler* EventRegistry::Lookup(std::string_view name) EventHandler* EventRegistry::Lookup(std::string_view name) {
{
auto it = handlers.find(name); auto it = handlers.find(name);
if ( it != handlers.end() ) if ( it != handlers.end() )
return it->second.get(); return it->second.get();
return nullptr; return nullptr;
} }
bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) {
{
return not_only_from_script.count(std::string(name)) > 0; return not_only_from_script.count(std::string(name)) > 0;
} }
EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
if ( v->GetFunc() && pattern->MatchExactly(v->Name()) ) if ( v->GetFunc() && pattern->MatchExactly(v->Name()) )
names.push_back(entry.first); names.push_back(entry.first);
} }
return names; return names;
} }
EventRegistry::string_list EventRegistry::UnusedHandlers() EventRegistry::string_list EventRegistry::UnusedHandlers() {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
if ( v->GetFunc() && ! v->Used() ) if ( v->GetFunc() && ! v->Used() )
names.push_back(entry.first); names.push_back(entry.first);
} }
return names; return names;
} }
EventRegistry::string_list EventRegistry::UsedHandlers() EventRegistry::string_list EventRegistry::UsedHandlers() {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
if ( v->GetFunc() && v->Used() ) if ( v->GetFunc() && v->Used() )
names.push_back(entry.first); names.push_back(entry.first);
} }
return names; return names;
} }
EventRegistry::string_list EventRegistry::AllHandlers() EventRegistry::string_list EventRegistry::AllHandlers() {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
names.push_back(entry.first); names.push_back(entry.first);
} }
return names; return names;
} }
void EventRegistry::PrintDebug() void EventRegistry::PrintDebug() {
{ for ( const auto& entry : handlers ) {
for ( const auto& entry : handlers )
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), v->GetFunc() ? "local" : "no",
v->GetFunc() ? "local" : "no", *v ? "active" : "not active"); *v ? "active" : "not active");
}
} }
}
void EventRegistry::SetErrorHandler(std::string_view name) void EventRegistry::SetErrorHandler(std::string_view name) {
{
EventHandler* eh = Lookup(name); EventHandler* eh = Lookup(name);
if ( eh ) if ( eh ) {
{
eh->SetErrorHandler(); eh->SetErrorHandler();
return; return;
} }
reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", std::string(name).c_str());
std::string(name).c_str()); }
}
void EventRegistry::ActivateAllHandlers() void EventRegistry::ActivateAllHandlers() {
{
auto event_names = AllHandlers(); auto event_names = AllHandlers();
for ( const auto& name : event_names ) for ( const auto& name : event_names ) {
{
if ( auto event = Lookup(name) ) if ( auto event = Lookup(name) )
event->SetGenerateAlways(); event->SetGenerateAlways();
} }
} }
EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) {
{
auto key = std::pair{kind, std::string{name}}; auto key = std::pair{kind, std::string{name}};
if ( const auto& it = event_groups.find(key); it != event_groups.end() ) if ( const auto& it = event_groups.find(key); it != event_groups.end() )
return it->second; return it->second;
auto group = std::make_shared<EventGroup>(kind, name); auto group = std::make_shared<EventGroup>(kind, name);
return event_groups.emplace(key, group).first->second; return event_groups.emplace(key, group).first->second;
} }
EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) {
{
auto key = std::pair{kind, std::string{name}}; auto key = std::pair{kind, std::string{name}};
if ( const auto& it = event_groups.find(key); it != event_groups.end() ) if ( const auto& it = event_groups.find(key); it != event_groups.end() )
return it->second; return it->second;
return nullptr; return nullptr;
} }
EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind), name(name) { } EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind), name(name) {}
// Run through all ScriptFunc instances associated with this group and // Run through all ScriptFunc instances associated with this group and
// update their bodies after a group's enable/disable state has changed. // update their bodies after a group's enable/disable state has changed.
@ -177,50 +154,36 @@ EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind),
// EventGroup is private friend with Func, so fiddling with the bodies // EventGroup is private friend with Func, so fiddling with the bodies
// and private members works and keeps the logic out of Func and away // and private members works and keeps the logic out of Func and away
// from the public zeek:: namespace. // from the public zeek:: namespace.
void EventGroup::UpdateFuncBodies() void EventGroup::UpdateFuncBodies() {
{ static auto is_group_disabled = [](const auto& g) { return g->IsDisabled(); };
static auto is_group_disabled = [](const auto& g)
{
return g->IsDisabled();
};
for ( auto& func : funcs ) for ( auto& func : funcs ) {
{
for ( auto& b : func->bodies ) for ( auto& b : func->bodies )
b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled); b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled);
static auto is_body_enabled = [](const auto& b) static auto is_body_enabled = [](const auto& b) { return ! b.disabled; };
{ func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(), is_body_enabled);
return ! b.disabled;
};
func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(),
is_body_enabled);
}
} }
}
void EventGroup::Enable() void EventGroup::Enable() {
{
if ( enabled ) if ( enabled )
return; return;
enabled = true; enabled = true;
UpdateFuncBodies(); UpdateFuncBodies();
} }
void EventGroup::Disable() void EventGroup::Disable() {
{
if ( ! enabled ) if ( ! enabled )
return; return;
enabled = false; enabled = false;
UpdateFuncBodies(); UpdateFuncBodies();
} }
void EventGroup::AddFunc(detail::ScriptFuncPtr f) void EventGroup::AddFunc(detail::ScriptFuncPtr f) { funcs.insert(f); }
{
funcs.insert(f);
}
} // namespace zeek } // namespace zeek

View file

@ -13,15 +13,13 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
// The different kinds of event groups that exist. // The different kinds of event groups that exist.
enum class EventGroupKind enum class EventGroupKind {
{
Attribute, Attribute,
Module, Module,
}; };
class EventGroup; class EventGroup;
class EventHandler; class EventHandler;
@ -30,15 +28,13 @@ class RE_Matcher;
using EventGroupPtr = std::shared_ptr<EventGroup>; using EventGroupPtr = std::shared_ptr<EventGroup>;
namespace detail namespace detail {
{
class ScriptFunc; class ScriptFunc;
using ScriptFuncPtr = zeek::IntrusivePtr<ScriptFunc>; using ScriptFuncPtr = zeek::IntrusivePtr<ScriptFunc>;
} } // namespace detail
// The registry keeps track of all events that we provide or handle. // The registry keeps track of all events that we provide or handle.
class EventRegistry final class EventRegistry final {
{
public: public:
EventRegistry(); EventRegistry();
~EventRegistry() noexcept; ~EventRegistry() noexcept;
@ -110,9 +106,8 @@ private:
std::unordered_set<std::string> not_only_from_script; std::unordered_set<std::string> not_only_from_script;
// Map event groups identified by kind and name to their instances. // Map event groups identified by kind and name to their instances.
std::map<std::pair<EventGroupKind, std::string>, std::shared_ptr<EventGroup>, std::less<>> std::map<std::pair<EventGroupKind, std::string>, std::shared_ptr<EventGroup>, std::less<>> event_groups;
event_groups; };
};
/** /**
* Event group. * Event group.
@ -135,8 +130,7 @@ private:
* bodies of the tracked ScriptFuncs and updates them to reflect the current * bodies of the tracked ScriptFuncs and updates them to reflect the current
* group state. * group state.
*/ */
class EventGroup final class EventGroup final {
{
public: public:
EventGroup(EventGroupKind kind, std::string_view name); EventGroup(EventGroupKind kind, std::string_view name);
~EventGroup() noexcept = default; ~EventGroup() noexcept = default;
@ -172,8 +166,8 @@ private:
std::string name; std::string name;
bool enabled = true; bool enabled = true;
std::unordered_set<detail::ScriptFuncPtr> funcs; std::unordered_set<detail::ScriptFuncPtr> funcs;
}; };
extern EventRegistry* event_registry; extern EventRegistry* event_registry;
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -5,19 +5,17 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
namespace zeek::detail namespace zeek::detail {
{
class ValTrace; class ValTrace;
class ValTraceMgr; class ValTraceMgr;
// Abstract class for capturing a single difference between two script-level // Abstract class for capturing a single difference between two script-level
// values. Includes notions of inserting, changing, or deleting a value. // values. Includes notions of inserting, changing, or deleting a value.
class ValDelta class ValDelta {
{
public: public:
ValDelta(const ValTrace* _vt) : vt(_vt) { } ValDelta(const ValTrace* _vt) : vt(_vt) {}
virtual ~ValDelta() { } virtual ~ValDelta() {}
// Return a string that performs the update operation, expressed // Return a string that performs the update operation, expressed
// as Zeek scripting. Does not include a terminating semicolon. // as Zeek scripting. Does not include a terminating semicolon.
@ -35,7 +33,7 @@ public:
protected: protected:
const ValTrace* vt; const ValTrace* vt;
}; };
using DeltaVector = std::vector<std::unique_ptr<ValDelta>>; using DeltaVector = std::vector<std::unique_ptr<ValDelta>>;
@ -43,8 +41,7 @@ using DeltaVector = std::vector<std::unique_ptr<ValDelta>>;
// For non-aggregates, this is simply the Val object, but for aggregates // For non-aggregates, this is simply the Val object, but for aggregates
// it is (recursively) each of the sub-elements, in a manner that can then // it is (recursively) each of the sub-elements, in a manner that can then
// be readily compared against future instances. // be readily compared against future instances.
class ValTrace class ValTrace {
{
public: public:
ValTrace(const ValPtr& v); ValTrace(const ValPtr& v);
~ValTrace() = default; ~ValTrace() = default;
@ -104,188 +101,157 @@ private:
ValPtr v; ValPtr v;
TypePtr t; // v's type, for convenience TypePtr t; // v's type, for convenience
}; };
// Captures the basic notion of a new, non-equivalent value being assigned. // Captures the basic notion of a new, non-equivalent value being assigned.
class DeltaReplaceValue : public ValDelta class DeltaReplaceValue : public ValDelta {
{
public: public:
DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) : ValDelta(_vt), new_val(std::move(_new_val)) {}
: ValDelta(_vt), new_val(std::move(_new_val))
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
ValPtr new_val; ValPtr new_val;
}; };
// Captures the notion of setting a record field. // Captures the notion of setting a record field.
class DeltaSetField : public ValDelta class DeltaSetField : public ValDelta {
{
public: public:
DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val) DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val)
: ValDelta(_vt), field(_field), new_val(std::move(_new_val)) : ValDelta(_vt), field(_field), new_val(std::move(_new_val)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
int field; int field;
ValPtr new_val; ValPtr new_val;
}; };
// Captures the notion of deleting a record field. // Captures the notion of deleting a record field.
class DeltaRemoveField : public ValDelta class DeltaRemoveField : public ValDelta {
{
public: public:
DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) { } DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
bool NeedsLHS() const override { return false; } bool NeedsLHS() const override { return false; }
private: private:
int field; int field;
}; };
// Captures the notion of creating a record from scratch. // Captures the notion of creating a record from scratch.
class DeltaRecordCreate : public ValDelta class DeltaRecordCreate : public ValDelta {
{
public: public:
DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of adding an element to a set. Use DeltaRemoveTableEntry to // Captures the notion of adding an element to a set. Use DeltaRemoveTableEntry to
// delete values. // delete values.
class DeltaSetSetEntry : public ValDelta class DeltaSetSetEntry : public ValDelta {
{
public: public:
DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) { } DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
bool NeedsLHS() const override { return false; } bool NeedsLHS() const override { return false; }
private: private:
ValPtr index; ValPtr index;
}; };
// Captures the notion of setting a table entry (which includes both changing // Captures the notion of setting a table entry (which includes both changing
// an existing one and adding a new one). Use DeltaRemoveTableEntry to // an existing one and adding a new one). Use DeltaRemoveTableEntry to
// delete values. // delete values.
class DeltaSetTableEntry : public ValDelta class DeltaSetTableEntry : public ValDelta {
{
public: public:
DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val) DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val)
: ValDelta(_vt), index(_index), new_val(std::move(_new_val)) : ValDelta(_vt), index(_index), new_val(std::move(_new_val)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
ValPtr index; ValPtr index;
ValPtr new_val; ValPtr new_val;
}; };
// Captures the notion of removing a table/set entry. // Captures the notion of removing a table/set entry.
class DeltaRemoveTableEntry : public ValDelta class DeltaRemoveTableEntry : public ValDelta {
{
public: public:
DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(std::move(_index)) {}
: ValDelta(_vt), index(std::move(_index))
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
bool NeedsLHS() const override { return false; } bool NeedsLHS() const override { return false; }
private: private:
ValPtr index; ValPtr index;
}; };
// Captures the notion of creating a set from scratch. // Captures the notion of creating a set from scratch.
class DeltaSetCreate : public ValDelta class DeltaSetCreate : public ValDelta {
{
public: public:
DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of creating a table from scratch. // Captures the notion of creating a table from scratch.
class DeltaTableCreate : public ValDelta class DeltaTableCreate : public ValDelta {
{
public: public:
DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of changing an element of a vector. // Captures the notion of changing an element of a vector.
class DeltaVectorSet : public ValDelta class DeltaVectorSet : public ValDelta {
{
public: public:
DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem) DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem)
: ValDelta(_vt), index(_index), elem(std::move(_elem)) : ValDelta(_vt), index(_index), elem(std::move(_elem)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
int index; int index;
ValPtr elem; ValPtr elem;
}; };
// Captures the notion of adding an entry to the end of a vector. // Captures the notion of adding an entry to the end of a vector.
class DeltaVectorAppend : public ValDelta class DeltaVectorAppend : public ValDelta {
{
public: public:
DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem) DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem)
: ValDelta(_vt), index(_index), elem(std::move(_elem)) : ValDelta(_vt), index(_index), elem(std::move(_elem)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
private: private:
int index; int index;
ValPtr elem; ValPtr elem;
}; };
// Captures the notion of replacing a vector wholesale. // Captures the notion of replacing a vector wholesale.
class DeltaVectorCreate : public ValDelta class DeltaVectorCreate : public ValDelta {
{
public: public:
DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Captures the notion of creating a value with an unsupported type // Captures the notion of creating a value with an unsupported type
// (like "opaque"). // (like "opaque").
class DeltaUnsupportedCreate : public ValDelta class DeltaUnsupportedCreate : public ValDelta {
{
public: public:
DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) { } DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) {}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
}; };
// Manages the changes to (or creation of) a variable used to represent // Manages the changes to (or creation of) a variable used to represent
// a value. // a value.
class DeltaGen class DeltaGen {
{
public: public:
DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def) DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def)
: val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), : val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), is_first_def(_is_first_def) {}
is_first_def(_is_first_def)
{
}
const ValPtr& GetVal() const { return val; } const ValPtr& GetVal() const { return val; }
const std::string& RHS() const { return rhs; } const std::string& RHS() const { return rhs; }
@ -305,13 +271,12 @@ private:
// Whether this is the first definition of the variable (in which // Whether this is the first definition of the variable (in which
// case we also need to declare the variable). // case we also need to declare the variable).
bool is_first_def; bool is_first_def;
}; };
using DeltaGenVec = std::vector<DeltaGen>; using DeltaGenVec = std::vector<DeltaGen>;
// Tracks a single event. // Tracks a single event.
class EventTrace class EventTrace {
{
public: public:
// Constructed in terms of the associated script function, "network // Constructed in terms of the associated script function, "network
// time" when the event occurred, and the position of this event // time" when the event occurred, and the position of this event
@ -323,8 +288,7 @@ public:
void SetArgs(std::string _args) { args = std::move(_args); } void SetArgs(std::string _args) { args = std::move(_args); }
// Adds to the trace an update for the given value. // Adds to the trace an update for the given value.
void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) {
{
auto& d = is_post ? post_deltas : deltas; auto& d = is_post ? post_deltas : deltas;
d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def)); d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def));
} }
@ -346,14 +310,12 @@ public:
// "predecessor", if non-nil, gives the event that came just before // "predecessor", if non-nil, gives the event that came just before
// this one (used for "# from script" annotations"). "successor", // this one (used for "# from script" annotations"). "successor",
// if not empty, gives the name of the successor internal event. // if not empty, gives the name of the successor internal event.
void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, std::string successor) const;
std::string successor) const;
private: private:
// "dvec" is either just our deltas, or the "post_deltas" of our // "dvec" is either just our deltas, or the "post_deltas" of our
// predecessor plus our deltas. // predecessor plus our deltas.
void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, int num_pre = 0) const;
int num_pre = 0) const;
const ScriptFunc* ev; const ScriptFunc* ev;
double nt; double nt;
@ -370,11 +332,10 @@ private:
// The event's name and a string representation of its arguments. // The event's name and a string representation of its arguments.
std::string name; std::string name;
std::string args; std::string args;
}; };
// Manages all of the events and associated values seen during the execution. // Manages all of the events and associated values seen during the execution.
class ValTraceMgr class ValTraceMgr {
{
public: public:
// Invoked to trace a new event with the associated arguments. // Invoked to trace a new event with the associated arguments.
void TraceEventValues(std::shared_ptr<EventTrace> et, const zeek::Args* args); void TraceEventValues(std::shared_ptr<EventTrace> et, const zeek::Args* args);
@ -470,12 +431,11 @@ private:
// Hang on to values we're tracking to make sure the pointers don't // Hang on to values we're tracking to make sure the pointers don't
// get reused when the main use of the value ends. // get reused when the main use of the value ends.
std::vector<ValPtr> vals; std::vector<ValPtr> vals;
}; };
// Manages tracing of all of the events seen during execution, including // Manages tracing of all of the events seen during execution, including
// the final generation of the trace script. // the final generation of the trace script.
class EventTraceMgr class EventTraceMgr {
{
public: public:
EventTraceMgr(const std::string& trace_file); EventTraceMgr(const std::string& trace_file);
~EventTraceMgr(); ~EventTraceMgr();
@ -498,9 +458,9 @@ private:
// The names of all of the script events that have been generated. // The names of all of the script events that have been generated.
std::unordered_set<std::string> script_events; std::unordered_set<std::string> script_events;
}; };
// If non-nil then we're doing event tracing. // If non-nil then we're doing event tracing.
extern std::unique_ptr<EventTraceMgr> etm; extern std::unique_ptr<EventTraceMgr> etm;
} // namespace zeek::detail } // namespace zeek::detail

File diff suppressed because it is too large Load diff

View file

@ -18,12 +18,11 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
namespace zeek namespace zeek {
{ template<class T>
template <class T> class IntrusivePtr; class IntrusivePtr;
namespace detail namespace detail {
{
class Frame; class Frame;
class Scope; class Scope;
@ -34,8 +33,7 @@ using ScopePtr = IntrusivePtr<Scope>;
using ScriptFuncPtr = IntrusivePtr<ScriptFunc>; using ScriptFuncPtr = IntrusivePtr<ScriptFunc>;
using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>; using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>;
enum ExprTag : int enum ExprTag : int {
{
EXPR_ANY = -1, EXPR_ANY = -1,
EXPR_NAME, EXPR_NAME,
EXPR_CONST, EXPR_CONST,
@ -108,7 +106,7 @@ enum ExprTag : int
EXPR_NOP, EXPR_NOP,
#define NUM_EXPRS (int(EXPR_NOP) + 1) #define NUM_EXPRS (int(EXPR_NOP) + 1)
}; };
extern const char* expr_name(ExprTag t); extern const char* expr_name(ExprTag t);
@ -150,17 +148,18 @@ using StmtPtr = IntrusivePtr<Stmt>;
class ExprOptInfo; class ExprOptInfo;
class Expr : public Obj class Expr : public Obj {
{
public: public:
const TypePtr& GetType() const { return type; } const TypePtr& GetType() const { return type; }
template <class T> IntrusivePtr<T> GetType() const { return cast_intrusive<T>(type); } template<class T>
IntrusivePtr<T> GetType() const {
return cast_intrusive<T>(type);
}
ExprTag Tag() const { return tag; } ExprTag Tag() const { return tag; }
Expr* Ref() Expr* Ref() {
{
zeek::Ref(this); zeek::Ref(this);
return this; return this;
} }
@ -270,10 +269,7 @@ public:
// True if the expression can serve as an operand to a reduced // True if the expression can serve as an operand to a reduced
// expression. // expression.
bool IsSingleton(Reducer* r) const bool IsSingleton(Reducer* r) const { return (tag == EXPR_NAME && IsReduced(r)) || tag == EXPR_CONST; }
{
return (tag == EXPR_NAME && IsReduced(r)) || tag == EXPR_CONST;
}
// True if the expression has no side effects, false otherwise. // True if the expression has no side effects, false otherwise.
virtual bool HasNoSideEffects() const { return IsPure(); } virtual bool HasNoSideEffects() const { return IsPure(); }
@ -287,8 +283,7 @@ public:
// True if (a) the expression has at least one operand, and (b) all // True if (a) the expression has at least one operand, and (b) all
// of its operands are constant. // of its operands are constant.
bool HasConstantOps() const bool HasConstantOps() const {
{
return GetOp1() && GetOp1()->IsConst() && return GetOp1() && GetOp1()->IsConst() &&
(! GetOp2() || (GetOp2()->IsConst() && (! GetOp3() || GetOp3()->IsConst()))); (! GetOp2() || (GetOp2()->IsConst() && (! GetOp3() || GetOp3()->IsConst())));
} }
@ -346,10 +341,7 @@ public:
// that's been assigned to the given expression via red_stmt. // that's been assigned to the given expression via red_stmt.
ExprPtr AssignToTemporary(ExprPtr e, Reducer* c, StmtPtr& red_stmt); ExprPtr AssignToTemporary(ExprPtr e, Reducer* c, StmtPtr& red_stmt);
// Same but for this expression. // Same but for this expression.
ExprPtr AssignToTemporary(Reducer* c, StmtPtr& red_stmt) ExprPtr AssignToTemporary(Reducer* c, StmtPtr& red_stmt) { return AssignToTemporary(ThisPtr(), c, red_stmt); }
{
return AssignToTemporary(ThisPtr(), c, red_stmt);
}
// If the expression always evaluates to the same value, returns // If the expression always evaluates to the same value, returns
// that value. Otherwise, returns nullptr. // that value. Otherwise, returns nullptr.
@ -379,8 +371,7 @@ public:
const Expr* Original() const { return original ? original->Original() : this; } const Expr* Original() const { return original ? original->Original() : this; }
// Designate the given Expr node as the original for this one. // Designate the given Expr node as the original for this one.
void SetOriginal(ExprPtr _orig) void SetOriginal(ExprPtr _orig) {
{
if ( ! original ) if ( ! original )
original = std::move(_orig); original = std::move(_orig);
} }
@ -392,16 +383,14 @@ public:
// code, which is always passing in "new XyzExpr(...)". This // code, which is always passing in "new XyzExpr(...)". This
// call, as a convenient side effect, transforms that bare pointer // call, as a convenient side effect, transforms that bare pointer
// into an ExprPtr. // into an ExprPtr.
virtual ExprPtr SetSucc(Expr* succ) virtual ExprPtr SetSucc(Expr* succ) {
{
succ->SetOriginal(ThisPtr()); succ->SetOriginal(ThisPtr());
if ( IsParen() ) if ( IsParen() )
succ->MarkParen(); succ->MarkParen();
return {AdoptRef{}, succ}; return {AdoptRef{}, succ};
} }
const detail::Location* GetLocationInfo() const override const detail::Location* GetLocationInfo() const override {
{
if ( original ) if ( original )
return original->GetLocationInfo(); return original->GetLocationInfo();
else else
@ -456,10 +445,9 @@ protected:
// Number of expressions created thus far. // Number of expressions created thus far.
static int num_exprs; static int num_exprs;
}; };
class NameExpr final : public Expr class NameExpr final : public Expr {
{
public: public:
explicit NameExpr(IDPtr id, bool const_init = false); explicit NameExpr(IDPtr id, bool const_init = false);
@ -490,10 +478,9 @@ protected:
IDPtr id; IDPtr id;
bool in_const_init; bool in_const_init;
}; };
class ConstExpr final : public Expr class ConstExpr final : public Expr {
{
public: public:
explicit ConstExpr(ValPtr val); explicit ConstExpr(ValPtr val);
@ -511,10 +498,9 @@ public:
protected: protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
ValPtr val; ValPtr val;
}; };
class UnaryExpr : public Expr class UnaryExpr : public Expr {
{
public: public:
Expr* Op() const { return op.get(); } Expr* Op() const { return op.get(); }
@ -547,10 +533,9 @@ protected:
virtual ValPtr Fold(Val* v) const; virtual ValPtr Fold(Val* v) const;
ExprPtr op; ExprPtr op;
}; };
class BinaryExpr : public Expr class BinaryExpr : public Expr {
{
public: public:
Expr* Op1() const { return op1.get(); } Expr* Op1() const { return op1.get(); }
Expr* Op2() const { return op2.get(); } Expr* Op2() const { return op2.get(); }
@ -580,8 +565,7 @@ public:
protected: protected:
BinaryExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) BinaryExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
: Expr(arg_tag), op1(std::move(arg_op1)), op2(std::move(arg_op2)) : Expr(arg_tag), op1(std::move(arg_op1)), op2(std::move(arg_op2)) {
{
if ( ! (op1 && op2) ) if ( ! (op1 && op2) )
return; return;
if ( op1->IsError() || op2->IsError() ) if ( op1->IsError() || op2->IsError() )
@ -634,10 +618,9 @@ protected:
ExprPtr op1; ExprPtr op1;
ExprPtr op2; ExprPtr op2;
}; };
class CloneExpr final : public UnaryExpr class CloneExpr final : public UnaryExpr {
{
public: public:
explicit CloneExpr(ExprPtr op); explicit CloneExpr(ExprPtr op);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -647,10 +630,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class IncrExpr final : public UnaryExpr class IncrExpr final : public UnaryExpr {
{
public: public:
IncrExpr(ExprTag tag, ExprPtr op); IncrExpr(ExprTag tag, ExprPtr op);
@ -666,10 +648,9 @@ public:
bool HasReducedOps(Reducer* c) const override { return false; } bool HasReducedOps(Reducer* c) const override { return false; }
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override;
}; };
class ComplementExpr final : public UnaryExpr class ComplementExpr final : public UnaryExpr {
{
public: public:
explicit ComplementExpr(ExprPtr op); explicit ComplementExpr(ExprPtr op);
@ -680,10 +661,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class NotExpr final : public UnaryExpr class NotExpr final : public UnaryExpr {
{
public: public:
explicit NotExpr(ExprPtr op); explicit NotExpr(ExprPtr op);
@ -694,10 +674,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class PosExpr final : public UnaryExpr class PosExpr final : public UnaryExpr {
{
public: public:
explicit PosExpr(ExprPtr op); explicit PosExpr(ExprPtr op);
@ -708,10 +687,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class NegExpr final : public UnaryExpr class NegExpr final : public UnaryExpr {
{
public: public:
explicit NegExpr(ExprPtr op); explicit NegExpr(ExprPtr op);
@ -722,10 +700,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class SizeExpr final : public UnaryExpr class SizeExpr final : public UnaryExpr {
{
public: public:
explicit SizeExpr(ExprPtr op); explicit SizeExpr(ExprPtr op);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -735,10 +712,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class AddExpr final : public BinaryExpr class AddExpr final : public BinaryExpr {
{
public: public:
AddExpr(ExprPtr op1, ExprPtr op2); AddExpr(ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -750,10 +726,9 @@ public:
protected: protected:
ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2); ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2);
}; };
class AddToExpr final : public BinaryExpr class AddToExpr final : public BinaryExpr {
{
public: public:
AddToExpr(ExprPtr op1, ExprPtr op2); AddToExpr(ExprPtr op1, ExprPtr op2);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -770,10 +745,9 @@ public:
private: private:
// Whether this operation is appending a single element to a vector. // Whether this operation is appending a single element to a vector.
bool is_vector_elem_append = false; bool is_vector_elem_append = false;
}; };
class RemoveFromExpr final : public BinaryExpr class RemoveFromExpr final : public BinaryExpr {
{
public: public:
bool IsPure() const override { return false; } bool IsPure() const override { return false; }
RemoveFromExpr(ExprPtr op1, ExprPtr op2); RemoveFromExpr(ExprPtr op1, ExprPtr op2);
@ -786,10 +760,9 @@ public:
bool IsReduced(Reducer* c) const override; bool IsReduced(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override;
}; };
class SubExpr final : public BinaryExpr class SubExpr final : public BinaryExpr {
{
public: public:
SubExpr(ExprPtr op1, ExprPtr op2); SubExpr(ExprPtr op1, ExprPtr op2);
@ -797,10 +770,9 @@ public:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
bool WillTransform(Reducer* c) const override; bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class TimesExpr final : public BinaryExpr class TimesExpr final : public BinaryExpr {
{
public: public:
TimesExpr(ExprPtr op1, ExprPtr op2); TimesExpr(ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -809,10 +781,9 @@ public:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
bool WillTransform(Reducer* c) const override; bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class DivideExpr final : public BinaryExpr class DivideExpr final : public BinaryExpr {
{
public: public:
DivideExpr(ExprPtr op1, ExprPtr op2); DivideExpr(ExprPtr op1, ExprPtr op2);
@ -820,10 +791,9 @@ public:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
bool WillTransform(Reducer* c) const override; bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class MaskExpr final : public BinaryExpr class MaskExpr final : public BinaryExpr {
{
public: public:
MaskExpr(ExprPtr op1, ExprPtr op2); MaskExpr(ExprPtr op1, ExprPtr op2);
@ -832,19 +802,17 @@ public:
protected: protected:
ValPtr AddrFold(Val* v1, Val* v2) const override; ValPtr AddrFold(Val* v1, Val* v2) const override;
}; };
class ModExpr final : public BinaryExpr class ModExpr final : public BinaryExpr {
{
public: public:
ModExpr(ExprPtr op1, ExprPtr op2); ModExpr(ExprPtr op1, ExprPtr op2);
// Optimization-related: // Optimization-related:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
class BoolExpr final : public BinaryExpr class BoolExpr final : public BinaryExpr {
{
public: public:
BoolExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); BoolExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
@ -860,10 +828,9 @@ public:
protected: protected:
bool IsTrue(const ExprPtr& e) const; bool IsTrue(const ExprPtr& e) const;
bool IsFalse(const ExprPtr& e) const; bool IsFalse(const ExprPtr& e) const;
}; };
class BitExpr final : public BinaryExpr class BitExpr final : public BinaryExpr {
{
public: public:
BitExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); BitExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
@ -871,10 +838,9 @@ public:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
bool WillTransform(Reducer* c) const override; bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class EqExpr final : public BinaryExpr class EqExpr final : public BinaryExpr {
{
public: public:
EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -887,10 +853,9 @@ public:
protected: protected:
ValPtr Fold(Val* v1, Val* v2) const override; ValPtr Fold(Val* v1, Val* v2) const override;
}; };
class RelExpr final : public BinaryExpr class RelExpr final : public BinaryExpr {
{
public: public:
RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -900,10 +865,9 @@ public:
bool WillTransform(Reducer* c) const override; bool WillTransform(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
bool InvertSense() override; bool InvertSense() override;
}; };
class CondExpr final : public Expr class CondExpr final : public Expr {
{
public: public:
CondExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); CondExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3);
@ -940,10 +904,9 @@ protected:
ExprPtr op1; ExprPtr op1;
ExprPtr op2; ExprPtr op2;
ExprPtr op3; ExprPtr op3;
}; };
class RefExpr final : public UnaryExpr class RefExpr final : public UnaryExpr {
{
public: public:
explicit RefExpr(ExprPtr op); explicit RefExpr(ExprPtr op);
@ -960,15 +923,14 @@ public:
// Reduce to simplified LHS form, i.e., a reference to only a name. // Reduce to simplified LHS form, i.e., a reference to only a name.
StmtPtr ReduceToLHS(Reducer* c); StmtPtr ReduceToLHS(Reducer* c);
}; };
class AssignExpr : public BinaryExpr class AssignExpr : public BinaryExpr {
{
public: public:
// If val is given, evaluating this expression will always yield the val // If val is given, evaluating this expression will always yield the val
// yet still perform the assignment. Used for triggers. // yet still perform the assignment. Used for triggers.
AssignExpr(ExprPtr op1, ExprPtr op2, bool is_init, ValPtr val = nullptr, AssignExpr(ExprPtr op1, ExprPtr op2, bool is_init, ValPtr val = nullptr, const AttributesPtr& attrs = nullptr,
const AttributesPtr& attrs = nullptr, bool type_check = true); bool type_check = true);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
TypePtr InitType() const override; TypePtr InitType() const override;
@ -1006,20 +968,18 @@ protected:
bool is_temp = false; // Optimization related bool is_temp = false; // Optimization related
ValPtr val; // optional ValPtr val; // optional
}; };
class IndexSliceAssignExpr final : public AssignExpr class IndexSliceAssignExpr final : public AssignExpr {
{
public: public:
IndexSliceAssignExpr(ExprPtr op1, ExprPtr op2, bool is_init); IndexSliceAssignExpr(ExprPtr op1, ExprPtr op2, bool is_init);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
// Optimization-related: // Optimization-related:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
class IndexExpr : public BinaryExpr class IndexExpr : public BinaryExpr {
{
public: public:
IndexExpr(ExprPtr op1, ListExprPtr op2, bool is_slice = false, bool is_inside_when = false); IndexExpr(ExprPtr op1, ListExprPtr op2, bool is_slice = false, bool is_inside_when = false);
@ -1052,7 +1012,7 @@ protected:
bool is_slice; bool is_slice;
bool is_inside_when; bool is_inside_when;
}; };
// The following execute the heart of IndexExpr functionality for // The following execute the heart of IndexExpr functionality for
// vector slices and strings. // vector slices and strings.
@ -1084,8 +1044,7 @@ extern VectorValPtr vector_int_select(VectorTypePtr vt, const VectorVal* v1, con
// //
// TODO: One Fine Day we should do the equivalent for accessing fields // TODO: One Fine Day we should do the equivalent for accessing fields
// in records, too. // in records, too.
class IndexExprWhen final : public IndexExpr class IndexExprWhen final : public IndexExpr {
{
public: public:
static inline std::vector<ValPtr> results = {}; static inline std::vector<ValPtr> results = {};
static inline int evaluating = 0; static inline int evaluating = 0;
@ -1094,20 +1053,16 @@ public:
static void EndEval() { --evaluating; } static void EndEval() { --evaluating; }
static std::vector<ValPtr> TakeAllResults() static std::vector<ValPtr> TakeAllResults() {
{
auto rval = std::move(results); auto rval = std::move(results);
results = {}; results = {};
return rval; return rval;
} }
IndexExprWhen(ExprPtr op1, ListExprPtr op2, bool is_slice = false) IndexExprWhen(ExprPtr op1, ListExprPtr op2, bool is_slice = false)
: IndexExpr(std::move(op1), std::move(op2), is_slice, true) : IndexExpr(std::move(op1), std::move(op2), is_slice, true) {}
{
}
ValPtr Eval(Frame* f) const override ValPtr Eval(Frame* f) const override {
{
auto v = IndexExpr::Eval(f); auto v = IndexExpr::Eval(f);
if ( v && evaluating > 0 ) if ( v && evaluating > 0 )
@ -1118,10 +1073,9 @@ public:
// Optimization-related: // Optimization-related:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
class FieldExpr final : public UnaryExpr class FieldExpr final : public UnaryExpr {
{
public: public:
FieldExpr(ExprPtr op, const char* field_name); FieldExpr(ExprPtr op, const char* field_name);
~FieldExpr() override; ~FieldExpr() override;
@ -1147,12 +1101,11 @@ protected:
const char* field_name; const char* field_name;
const TypeDecl* td; const TypeDecl* td;
int field; // -1 = attributes int field; // -1 = attributes
}; };
// "rec?$fieldname" is true if the value of $fieldname in rec is not nil. // "rec?$fieldname" is true if the value of $fieldname in rec is not nil.
// "rec?$$attrname" is true if the attribute attrname is not nil. // "rec?$$attrname" is true if the attribute attrname is not nil.
class HasFieldExpr final : public UnaryExpr class HasFieldExpr final : public UnaryExpr {
{
public: public:
HasFieldExpr(ExprPtr op, const char* field_name); HasFieldExpr(ExprPtr op, const char* field_name);
~HasFieldExpr() override; ~HasFieldExpr() override;
@ -1173,10 +1126,9 @@ protected:
const char* field_name; const char* field_name;
int field; int field;
}; };
class RecordConstructorExpr final : public Expr class RecordConstructorExpr final : public Expr {
{
public: public:
explicit RecordConstructorExpr(ListExprPtr constructor_list); explicit RecordConstructorExpr(ListExprPtr constructor_list);
@ -1205,10 +1157,9 @@ protected:
ListExprPtr op; ListExprPtr op;
std::optional<std::vector<int>> map; std::optional<std::vector<int>> map;
}; };
class TableConstructorExpr final : public UnaryExpr class TableConstructorExpr final : public UnaryExpr {
{
public: public:
TableConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs, TableConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs,
TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr);
@ -1231,10 +1182,9 @@ protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
AttributesPtr attrs; AttributesPtr attrs;
}; };
class SetConstructorExpr final : public UnaryExpr class SetConstructorExpr final : public UnaryExpr {
{
public: public:
SetConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs, SetConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs,
TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr);
@ -1257,10 +1207,9 @@ protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
AttributesPtr attrs; AttributesPtr attrs;
}; };
class VectorConstructorExpr final : public UnaryExpr class VectorConstructorExpr final : public UnaryExpr {
{
public: public:
explicit VectorConstructorExpr(ListExprPtr constructor_list, TypePtr arg_type = nullptr); explicit VectorConstructorExpr(ListExprPtr constructor_list, TypePtr arg_type = nullptr);
@ -1273,10 +1222,9 @@ public:
protected: protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
}; };
class FieldAssignExpr final : public UnaryExpr class FieldAssignExpr final : public UnaryExpr {
{
public: public:
FieldAssignExpr(const char* field_name, ExprPtr value); FieldAssignExpr(const char* field_name, ExprPtr value);
@ -1301,10 +1249,9 @@ protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
std::string field_name; std::string field_name;
}; };
class ArithCoerceExpr final : public UnaryExpr class ArithCoerceExpr final : public UnaryExpr {
{
public: public:
ArithCoerceExpr(ExprPtr op, TypeTag t); ArithCoerceExpr(ExprPtr op, TypeTag t);
@ -1317,10 +1264,9 @@ public:
protected: protected:
ValPtr FoldSingleVal(ValPtr v, const TypePtr& t) const; ValPtr FoldSingleVal(ValPtr v, const TypePtr& t) const;
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class RecordCoerceExpr final : public UnaryExpr class RecordCoerceExpr final : public UnaryExpr {
{
public: public:
RecordCoerceExpr(ExprPtr op, RecordTypePtr r); RecordCoerceExpr(ExprPtr op, RecordTypePtr r);
@ -1335,12 +1281,11 @@ protected:
// For each super-record slot, gives subrecord slot with which to // For each super-record slot, gives subrecord slot with which to
// fill it. // fill it.
std::vector<int> map; std::vector<int> map;
}; };
extern RecordValPtr coerce_to_record(RecordTypePtr rt, Val* v, const std::vector<int>& map); extern RecordValPtr coerce_to_record(RecordTypePtr rt, Val* v, const std::vector<int>& map);
class TableCoerceExpr final : public UnaryExpr class TableCoerceExpr final : public UnaryExpr {
{
public: public:
TableCoerceExpr(ExprPtr op, TableTypePtr r, bool type_check = true); TableCoerceExpr(ExprPtr op, TableTypePtr r, bool type_check = true);
~TableCoerceExpr() override = default; ~TableCoerceExpr() override = default;
@ -1350,10 +1295,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class VectorCoerceExpr final : public UnaryExpr class VectorCoerceExpr final : public UnaryExpr {
{
public: public:
VectorCoerceExpr(ExprPtr op, VectorTypePtr v); VectorCoerceExpr(ExprPtr op, VectorTypePtr v);
~VectorCoerceExpr() override = default; ~VectorCoerceExpr() override = default;
@ -1363,10 +1307,9 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class ScheduleTimer final : public Timer class ScheduleTimer final : public Timer {
{
public: public:
ScheduleTimer(const EventHandlerPtr& event, zeek::Args args, double t); ScheduleTimer(const EventHandlerPtr& event, zeek::Args args, double t);
~ScheduleTimer() override = default; ~ScheduleTimer() override = default;
@ -1376,10 +1319,9 @@ public:
protected: protected:
EventHandlerPtr event; EventHandlerPtr event;
zeek::Args args; zeek::Args args;
}; };
class ScheduleExpr final : public Expr class ScheduleExpr final : public Expr {
{
public: public:
ScheduleExpr(ExprPtr when, EventExprPtr event); ScheduleExpr(ExprPtr when, EventExprPtr event);
@ -1411,10 +1353,9 @@ protected:
ExprPtr when; ExprPtr when;
EventExprPtr event; EventExprPtr event;
}; };
class InExpr final : public BinaryExpr class InExpr final : public BinaryExpr {
{
public: public:
InExpr(ExprPtr op1, ExprPtr op2); InExpr(ExprPtr op1, ExprPtr op2);
@ -1425,10 +1366,9 @@ public:
protected: protected:
ValPtr Fold(Val* v1, Val* v2) const override; ValPtr Fold(Val* v1, Val* v2) const override;
}; };
class CallExpr final : public Expr class CallExpr final : public Expr {
{
public: public:
CallExpr(ExprPtr func, ListExprPtr args, bool in_hook = false, bool in_when = false); CallExpr(ExprPtr func, ListExprPtr args, bool in_hook = false, bool in_when = false);
@ -1458,15 +1398,14 @@ protected:
ExprPtr func; ExprPtr func;
ListExprPtr args; ListExprPtr args;
bool in_when; bool in_when;
}; };
/** /**
* Class that represents an anonymous function expression in Zeek. * Class that represents an anonymous function expression in Zeek.
* On evaluation, captures the frame that it is evaluated in. This becomes * On evaluation, captures the frame that it is evaluated in. This becomes
* the closure for the instance of the function that it creates. * the closure for the instance of the function that it creates.
*/ */
class LambdaExpr final : public Expr class LambdaExpr final : public Expr {
{
public: public:
LambdaExpr(FunctionIngredientsPtr ingredients, IDPList outer_ids, std::string name = "", LambdaExpr(FunctionIngredientsPtr ingredients, IDPList outer_ids, std::string name = "",
StmtPtr when_parent = nullptr); StmtPtr when_parent = nullptr);
@ -1528,12 +1467,11 @@ private:
IDSet private_captures; IDSet private_captures;
std::string my_name; std::string my_name;
}; };
// This comes before EventExpr so that EventExpr::GetOp1 can return its // This comes before EventExpr so that EventExpr::GetOp1 can return its
// arguments as convertible to ExprPtr. // arguments as convertible to ExprPtr.
class ListExpr : public Expr class ListExpr : public Expr {
{
public: public:
ListExpr(); ListExpr();
explicit ListExpr(ExprPtr e); explicit ListExpr(ExprPtr e);
@ -1568,10 +1506,9 @@ protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
ExprPList exprs; ExprPList exprs;
}; };
class EventExpr final : public Expr class EventExpr final : public Expr {
{
public: public:
EventExpr(const char* name, ListExprPtr args); EventExpr(const char* name, ListExprPtr args);
@ -1600,16 +1537,14 @@ protected:
std::string name; std::string name;
EventHandlerPtr handler; EventHandlerPtr handler;
ListExprPtr args; ListExprPtr args;
}; };
class RecordAssignExpr final : public ListExpr class RecordAssignExpr final : public ListExpr {
{
public: public:
RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init); RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init);
}; };
class CastExpr final : public UnaryExpr class CastExpr final : public UnaryExpr {
{
public: public:
CastExpr(ExprPtr op, TypePtr t); CastExpr(ExprPtr op, TypePtr t);
@ -1619,14 +1554,13 @@ public:
protected: protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
}; };
// Returns the value 'v' cast to type 't'. On an error, returns nil // Returns the value 'v' cast to type 't'. On an error, returns nil
// and populates "error" with an error message. // and populates "error" with an error message.
extern ValPtr cast_value(ValPtr v, const TypePtr& t, std::string& error); extern ValPtr cast_value(ValPtr v, const TypePtr& t, std::string& error);
class IsExpr final : public UnaryExpr class IsExpr final : public UnaryExpr {
{
public: public:
IsExpr(ExprPtr op, TypePtr t); IsExpr(ExprPtr op, TypePtr t);
@ -1641,13 +1575,11 @@ protected:
private: private:
TypePtr t; TypePtr t;
}; };
class InlineExpr : public Expr class InlineExpr : public Expr {
{
public: public:
InlineExpr(ListExprPtr arg_args, std::vector<IDPtr> params, StmtPtr body, int frame_offset, InlineExpr(ListExprPtr arg_args, std::vector<IDPtr> params, StmtPtr body, int frame_offset, TypePtr ret_type);
TypePtr ret_type);
bool IsPure() const override; bool IsPure() const override;
@ -1672,12 +1604,11 @@ protected:
int frame_offset; int frame_offset;
ListExprPtr args; ListExprPtr args;
StmtPtr body; StmtPtr body;
}; };
// A companion to AddToExpr that's for vector-append, instantiated during // A companion to AddToExpr that's for vector-append, instantiated during
// the reduction process. // the reduction process.
class AppendToExpr : public BinaryExpr class AppendToExpr : public BinaryExpr {
{
public: public:
AppendToExpr(ExprPtr op1, ExprPtr op2); AppendToExpr(ExprPtr op1, ExprPtr op2);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -1688,11 +1619,10 @@ public:
bool IsReduced(Reducer* c) const override; bool IsReduced(Reducer* c) const override;
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override;
}; };
// An internal class for reduced form. // An internal class for reduced form.
class IndexAssignExpr : public BinaryExpr class IndexAssignExpr : public BinaryExpr {
{
public: public:
// "op1[op2] = op3", all reduced. // "op1[op2] = op3", all reduced.
IndexAssignExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); IndexAssignExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3);
@ -1716,11 +1646,10 @@ protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
ExprPtr op3; // assignment RHS ExprPtr op3; // assignment RHS
}; };
// An internal class for reduced form. // An internal class for reduced form.
class FieldLHSAssignExpr : public BinaryExpr class FieldLHSAssignExpr : public BinaryExpr {
{
public: public:
// "op1$field = RHS", where RHS is reduced with respect to // "op1$field = RHS", where RHS is reduced with respect to
// ReduceToFieldAssignment(). // ReduceToFieldAssignment().
@ -1744,12 +1673,11 @@ protected:
const char* field_name; const char* field_name;
int field; int field;
}; };
// Expression to explicitly capture conversion to an "any" type, rather // Expression to explicitly capture conversion to an "any" type, rather
// than it occurring implicitly during script interpretation. // than it occurring implicitly during script interpretation.
class CoerceToAnyExpr : public UnaryExpr class CoerceToAnyExpr : public UnaryExpr {
{
public: public:
CoerceToAnyExpr(ExprPtr op); CoerceToAnyExpr(ExprPtr op);
@ -1757,11 +1685,10 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
// Same, but for conversion from an "any" type. // Same, but for conversion from an "any" type.
class CoerceFromAnyExpr : public UnaryExpr class CoerceFromAnyExpr : public UnaryExpr {
{
public: public:
CoerceFromAnyExpr(ExprPtr op, TypePtr to_type); CoerceFromAnyExpr(ExprPtr op, TypePtr to_type);
@ -1769,11 +1696,10 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
// ... and for conversion from a "vector of any" type. // ... and for conversion from a "vector of any" type.
class CoerceFromAnyVecExpr : public UnaryExpr class CoerceFromAnyVecExpr : public UnaryExpr {
{
public: public:
// to_type is yield type, not VectorType. // to_type is yield type, not VectorType.
CoerceFromAnyVecExpr(ExprPtr op, TypePtr to_type); CoerceFromAnyVecExpr(ExprPtr op, TypePtr to_type);
@ -1784,11 +1710,10 @@ public:
protected: protected:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
// Expression used to explicitly capture [a, b, c, ...] = x assignments. // Expression used to explicitly capture [a, b, c, ...] = x assignments.
class AnyIndexExpr : public UnaryExpr class AnyIndexExpr : public UnaryExpr {
{
public: public:
AnyIndexExpr(ExprPtr op, int index); AnyIndexExpr(ExprPtr op, int index);
@ -1803,13 +1728,12 @@ protected:
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
int index; int index;
}; };
// Used internally for optimization, when a placeholder is needed. // Used internally for optimization, when a placeholder is needed.
class NopExpr : public Expr class NopExpr : public Expr {
{
public: public:
explicit NopExpr() : Expr(EXPR_NOP) { } explicit NopExpr() : Expr(EXPR_NOP) {}
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -1819,18 +1743,17 @@ public:
protected: protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
}; };
// Assigns v1[v2] = v3. Returns an error message, or nullptr on success. // Assigns v1[v2] = v3. Returns an error message, or nullptr on success.
// Factored out so that compiled code can call it as well as the interpreter. // Factored out so that compiled code can call it as well as the interpreter.
extern const char* assign_to_index(ValPtr v1, ValPtr v2, ValPtr v3, bool& iterators_invalidated); extern const char* assign_to_index(ValPtr v1, ValPtr v2, ValPtr v3, bool& iterators_invalidated);
inline Val* Expr::ExprVal() const inline Val* Expr::ExprVal() const {
{
if ( ! IsConst() ) if ( ! IsConst() )
BadTag("ExprVal::Val", expr_name(tag), expr_name(EXPR_CONST)); BadTag("ExprVal::Val", expr_name(tag), expr_name(EXPR_CONST));
return ((ConstExpr*)this)->Value(); return ((ConstExpr*)this)->Value();
} }
// Decides whether to return an AssignExpr or a RecordAssignExpr. // Decides whether to return an AssignExpr or a RecordAssignExpr.
extern ExprPtr get_assign_expr(ExprPtr op1, ExprPtr op2, bool is_init); extern ExprPtr get_assign_expr(ExprPtr op1, ExprPtr op2, bool is_init);
@ -1869,25 +1792,13 @@ extern std::optional<std::vector<ValPtr>> eval_list(Frame* f, const ListExpr* l)
extern bool expr_greater(const Expr* e1, const Expr* e2); extern bool expr_greater(const Expr* e1, const Expr* e2);
// True if the given Expr* has a vector type // True if the given Expr* has a vector type
inline bool is_vector(Expr* e) inline bool is_vector(Expr* e) { return e->GetType()->Tag() == TYPE_VECTOR; }
{ inline bool is_vector(const ExprPtr& e) { return is_vector(e.get()); }
return e->GetType()->Tag() == TYPE_VECTOR;
}
inline bool is_vector(const ExprPtr& e)
{
return is_vector(e.get());
}
// True if the given Expr* has a list type // True if the given Expr* has a list type
inline bool is_list(Expr* e) inline bool is_list(Expr* e) { return e->GetType()->Tag() == TYPE_LIST; }
{
return e->GetType()->Tag() == TYPE_LIST;
}
inline bool is_list(const ExprPtr& e) inline bool is_list(const ExprPtr& e) { return is_list(e.get()); }
{
return is_list(e.get());
}
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -31,20 +31,17 @@
#include "zeek/Type.h" #include "zeek/Type.h"
#include "zeek/Var.h" #include "zeek/Var.h"
namespace zeek namespace zeek {
{
std::list<std::pair<std::string, File*>> File::open_files; std::list<std::pair<std::string, File*>> File::open_files;
// Maximizes the number of open file descriptors. // Maximizes the number of open file descriptors.
static void maximize_num_fds() static void maximize_num_fds() {
{
struct rlimit rl; struct rlimit rl;
if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 ) if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 )
reporter->FatalError("maximize_num_fds(): getrlimit failed"); reporter->FatalError("maximize_num_fds(): getrlimit failed");
if ( rl.rlim_max == RLIM_INFINITY ) if ( rl.rlim_max == RLIM_INFINITY ) {
{
// Don't try raising the current limit. // Don't try raising the current limit.
return; return;
} }
@ -54,29 +51,26 @@ static void maximize_num_fds()
if ( setrlimit(RLIMIT_NOFILE, &rl) < 0 ) if ( setrlimit(RLIMIT_NOFILE, &rl) < 0 )
reporter->FatalError("maximize_num_fds(): setrlimit failed"); reporter->FatalError("maximize_num_fds(): setrlimit failed");
} }
File::File(FILE* arg_f) File::File(FILE* arg_f) {
{
Init(); Init();
f = arg_f; f = arg_f;
name = access = nullptr; name = access = nullptr;
t = base_type(TYPE_STRING); t = base_type(TYPE_STRING);
is_open = (f != nullptr); is_open = (f != nullptr);
} }
File::File(FILE* arg_f, const char* arg_name, const char* arg_access) File::File(FILE* arg_f, const char* arg_name, const char* arg_access) {
{
Init(); Init();
f = arg_f; f = arg_f;
name = util::copy_string(arg_name); name = util::copy_string(arg_name);
access = util::copy_string(arg_access); access = util::copy_string(arg_access);
t = base_type(TYPE_STRING); t = base_type(TYPE_STRING);
is_open = (f != nullptr); is_open = (f != nullptr);
} }
File::File(const char* arg_name, const char* arg_access) File::File(const char* arg_name, const char* arg_access) {
{
Init(); Init();
f = nullptr; f = nullptr;
name = util::copy_string(arg_name); name = util::copy_string(arg_name);
@ -93,15 +87,13 @@ File::File(const char* arg_name, const char* arg_access)
if ( f ) if ( f )
is_open = true; is_open = true;
else if ( ! Open() ) else if ( ! Open() ) {
{
reporter->Error("cannot open %s: %s", name, strerror(errno)); reporter->Error("cannot open %s: %s", name, strerror(errno));
is_open = false; is_open = false;
} }
} }
const char* File::Name() const const char* File::Name() const {
{
if ( name ) if ( name )
return name; return name;
@ -115,15 +107,13 @@ const char* File::Name() const
return "/dev/stderr"; return "/dev/stderr";
return nullptr; return nullptr;
} }
bool File::Open(FILE* file, const char* mode) bool File::Open(FILE* file, const char* mode) {
{
static bool fds_maximized = false; static bool fds_maximized = false;
open_time = run_state::network_time ? run_state::network_time : util::current_time(); open_time = run_state::network_time ? run_state::network_time : util::current_time();
if ( ! fds_maximized ) if ( ! fds_maximized ) {
{
// Haven't initialized yet. // Haven't initialized yet.
maximize_num_fds(); maximize_num_fds();
fds_maximized = true; fds_maximized = true;
@ -131,8 +121,7 @@ bool File::Open(FILE* file, const char* mode)
f = file; f = file;
if ( ! f ) if ( ! f ) {
{
if ( ! mode ) if ( ! mode )
f = fopen(name, access); f = fopen(name, access);
else else
@ -141,8 +130,7 @@ bool File::Open(FILE* file, const char* mode)
SetBuf(buffered); SetBuf(buffered);
if ( ! f ) if ( ! f ) {
{
is_open = false; is_open = false;
return false; return false;
} }
@ -153,10 +141,9 @@ bool File::Open(FILE* file, const char* mode)
RaiseOpenEvent(); RaiseOpenEvent();
return true; return true;
} }
File::~File() File::~File() {
{
Close(); Close();
Unref(attrs); Unref(attrs);
@ -166,10 +153,9 @@ File::~File()
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
heap_checker->UnIgnoreObject(this); heap_checker->UnIgnoreObject(this);
#endif #endif
} }
void File::Init() void File::Init() {
{
open_time = 0; open_time = 0;
is_open = false; is_open = false;
attrs = nullptr; attrs = nullptr;
@ -179,15 +165,11 @@ void File::Init()
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
heap_checker->IgnoreObject(this); heap_checker->IgnoreObject(this);
#endif #endif
} }
FILE* File::FileHandle() FILE* File::FileHandle() { return f; }
{
return f;
}
FILE* File::Seek(long new_position) FILE* File::Seek(long new_position) {
{
if ( ! FileHandle() ) if ( ! FileHandle() )
return nullptr; return nullptr;
@ -195,10 +177,9 @@ FILE* File::Seek(long new_position)
reporter->Error("seek failed"); reporter->Error("seek failed");
return f; return f;
} }
void File::SetBuf(bool arg_buffered) void File::SetBuf(bool arg_buffered) {
{
if ( ! f ) if ( ! f )
return; return;
@ -206,10 +187,9 @@ void File::SetBuf(bool arg_buffered)
reporter->Error("setvbuf failed"); reporter->Error("setvbuf failed");
buffered = arg_buffered; buffered = arg_buffered;
} }
bool File::Close() bool File::Close() {
{
if ( ! is_open ) if ( ! is_open )
return true; return true;
@ -228,26 +208,21 @@ bool File::Close()
Unlink(); Unlink();
return true; return true;
} }
void File::Unlink() void File::Unlink() {
{ for ( auto it = open_files.begin(); it != open_files.end(); ++it ) {
for ( auto it = open_files.begin(); it != open_files.end(); ++it ) if ( (*it).second == this ) {
{
if ( (*it).second == this )
{
open_files.erase(it); open_files.erase(it);
return; return;
} }
} }
} }
void File::Describe(ODesc* d) const void File::Describe(ODesc* d) const {
{
d->AddSP("file"); d->AddSP("file");
if ( name ) if ( name ) {
{
d->Add("\""); d->Add("\"");
d->Add(name); d->Add(name);
d->AddSP("\""); d->AddSP("\"");
@ -258,10 +233,9 @@ void File::Describe(ODesc* d) const
t->Describe(d); t->Describe(d);
else else
d->Add("(no type)"); d->Add("(no type)");
} }
void File::SetAttrs(detail::Attributes* arg_attrs) void File::SetAttrs(detail::Attributes* arg_attrs) {
{
if ( ! arg_attrs ) if ( ! arg_attrs )
return; return;
@ -270,10 +244,9 @@ void File::SetAttrs(detail::Attributes* arg_attrs)
if ( attrs->Find(detail::ATTR_RAW_OUTPUT) ) if ( attrs->Find(detail::ATTR_RAW_OUTPUT) )
EnableRawOutput(); EnableRawOutput();
} }
RecordVal* File::Rotate() RecordVal* File::Rotate() {
{
if ( ! is_open ) if ( ! is_open )
return nullptr; return nullptr;
@ -285,8 +258,7 @@ RecordVal* File::Rotate()
auto* info = new RecordVal(rotate_info); auto* info = new RecordVal(rotate_info);
FILE* newf = util::detail::rotate_file(name, info); FILE* newf = util::detail::rotate_file(name, info);
if ( ! newf ) if ( ! newf ) {
{
Unref(info); Unref(info);
return nullptr; return nullptr;
} }
@ -300,20 +272,17 @@ RecordVal* File::Rotate()
Open(newf); Open(newf);
return info; return info;
} }
void File::CloseOpenFiles() void File::CloseOpenFiles() {
{
auto it = open_files.begin(); auto it = open_files.begin();
while ( it != open_files.end() ) while ( it != open_files.end() ) {
{
auto el = it++; auto el = it++;
(*el).second->Close(); (*el).second->Close();
} }
} }
bool File::Write(const char* data, int len) bool File::Write(const char* data, int len) {
{
if ( ! is_open ) if ( ! is_open )
return false; return false;
@ -324,38 +293,34 @@ bool File::Write(const char* data, int len)
return false; return false;
return true; return true;
} }
void File::RaiseOpenEvent() void File::RaiseOpenEvent() {
{
if ( ! ::file_opened ) if ( ! ::file_opened )
return; return;
FilePtr bf{NewRef{}, this}; FilePtr bf{NewRef{}, this};
auto* event = new Event(::file_opened, {make_intrusive<FileVal>(std::move(bf))}); auto* event = new Event(::file_opened, {make_intrusive<FileVal>(std::move(bf))});
event_mgr.Dispatch(event, true); event_mgr.Dispatch(event, true);
} }
double File::Size() double File::Size() {
{
fflush(f); fflush(f);
struct stat s; struct stat s;
if ( fstat(fileno(f), &s) < 0 ) if ( fstat(fileno(f), &s) < 0 ) {
{
reporter->Error("can't stat fd for %s: %s", name, strerror(errno)); reporter->Error("can't stat fd for %s: %s", name, strerror(errno));
return 0; return 0;
} }
return s.st_size; return s.st_size;
} }
FilePtr File::Get(const char* name) FilePtr File::Get(const char* name) {
{
for ( const auto& el : open_files ) for ( const auto& el : open_files )
if ( el.first == name ) if ( el.first == name )
return {NewRef{}, el.second}; return {NewRef{}, el.second};
return make_intrusive<File>(name, "w"); return make_intrusive<File>(name, "w");
} }
} // namespace zeek } // namespace zeek

View file

@ -10,18 +10,16 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
namespace detail namespace detail {
{
class PrintStmt; class PrintStmt;
class Attributes; class Attributes;
extern void do_print_stmt(const std::vector<ValPtr>& vals); extern void do_print_stmt(const std::vector<ValPtr>& vals);
} // namespace detail; } // namespace detail
class RecordVal; class RecordVal;
class Type; class Type;
@ -30,8 +28,7 @@ using TypePtr = IntrusivePtr<Type>;
class File; class File;
using FilePtr = IntrusivePtr<File>; using FilePtr = IntrusivePtr<File>;
class File final : public Obj class File final : public Obj {
{
public: public:
explicit File(FILE* arg_f); explicit File(FILE* arg_f);
File(FILE* arg_f, const char* filename, const char* access); File(FILE* arg_f, const char* filename, const char* access);
@ -118,6 +115,6 @@ protected:
private: private:
static std::list<std::pair<std::string, File*>> open_files; static std::list<std::pair<std::string, File*>> open_files;
}; };
} // namespace zeek } // namespace zeek

View file

@ -13,12 +13,10 @@
#include <winsock2.h> #include <winsock2.h>
#define fatalError(...) \ #define fatalError(...) \
do \ do { \
{ \
if ( reporter ) \ if ( reporter ) \
reporter->FatalError(__VA_ARGS__); \ reporter->FatalError(__VA_ARGS__); \
else \ else { \
{ \
fprintf(stderr, __VA_ARGS__); \ fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \ fprintf(stderr, "\n"); \
_exit(1); \ _exit(1); \
@ -27,26 +25,22 @@
#endif #endif
namespace zeek::detail namespace zeek::detail {
{
Flare::Flare() Flare::Flare()
#ifndef _MSC_VER #ifndef _MSC_VER
: pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) : pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) {
{ }
}
#else #else
{ {
WSADATA wsaData; WSADATA wsaData;
if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 ) if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 )
fatalError("WSAStartup failure: %d", WSAGetLastError()); fatalError("WSAStartup failure: %d", WSAGetLastError());
recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
if ( recvfd == (int)INVALID_SOCKET ) if ( recvfd == (int)INVALID_SOCKET )
fatalError("WSASocket failure: %d", WSAGetLastError()); fatalError("WSASocket failure: %d", WSAGetLastError());
sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
if ( sendfd == (int)INVALID_SOCKET ) if ( sendfd == (int)INVALID_SOCKET )
fatalError("WSASocket failure: %d", WSAGetLastError()); fatalError("WSASocket failure: %d", WSAGetLastError());
@ -61,11 +55,10 @@ Flare::Flare()
fatalError("getsockname failure: %d", WSAGetLastError()); fatalError("getsockname failure: %d", WSAGetLastError());
if ( connect(sendfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR ) if ( connect(sendfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR )
fatalError("connect failure: %d", WSAGetLastError()); fatalError("connect failure: %d", WSAGetLastError());
} }
#endif #endif
[[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) [[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) {
{
if ( signal_safe ) if ( signal_safe )
abort(); abort();
@ -74,19 +67,16 @@ Flare::Flare()
if ( reporter ) if ( reporter )
reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf); reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf);
else else {
{
fprintf(stderr, "unexpected pipe %s failure: %s", which, buf); fprintf(stderr, "unexpected pipe %s failure: %s", which, buf);
abort(); abort();
} }
} }
void Flare::Fire(bool signal_safe) void Flare::Fire(bool signal_safe) {
{
char tmp = 0; char tmp = 0;
for ( ;; ) for ( ;; ) {
{
#ifndef _MSC_VER #ifndef _MSC_VER
int n = write(pipe.WriteFD(), &tmp, 1); int n = write(pipe.WriteFD(), &tmp, 1);
@ -97,8 +87,7 @@ void Flare::Fire(bool signal_safe)
// Success -- wrote a byte to pipe. // Success -- wrote a byte to pipe.
break; break;
if ( n < 0 ) if ( n < 0 ) {
{
#ifdef _MSC_VER #ifdef _MSC_VER
errno = WSAGetLastError(); errno = WSAGetLastError();
bad_pipe_op("send", signal_safe); bad_pipe_op("send", signal_safe);
@ -116,22 +105,19 @@ void Flare::Fire(bool signal_safe)
// No error, but didn't write a byte: try again. // No error, but didn't write a byte: try again.
} }
} }
int Flare::Extinguish(bool signal_safe) int Flare::Extinguish(bool signal_safe) {
{
int rval = 0; int rval = 0;
char tmp[256]; char tmp[256];
for ( ;; ) for ( ;; ) {
{
#ifndef _MSC_VER #ifndef _MSC_VER
int n = read(pipe.ReadFD(), &tmp, sizeof(tmp)); int n = read(pipe.ReadFD(), &tmp, sizeof(tmp));
#else #else
int n = recv(recvfd, tmp, sizeof(tmp), 0); int n = recv(recvfd, tmp, sizeof(tmp), 0);
#endif #endif
if ( n >= 0 ) if ( n >= 0 ) {
{
rval += n; rval += n;
// Pipe may not be empty yet: try again. // Pipe may not be empty yet: try again.
continue; continue;
@ -154,6 +140,6 @@ int Flare::Extinguish(bool signal_safe)
} }
return rval; return rval;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,11 +6,9 @@
#include "Pipe.h" #include "Pipe.h"
#endif #endif
namespace zeek::detail namespace zeek::detail {
{
class Flare class Flare {
{
public: public:
/** /**
* Create a flare object that can be used to signal a "ready" status via * Create a flare object that can be used to signal a "ready" status via
@ -57,6 +55,6 @@ private:
#else #else
int sendfd, recvfd; int sendfd, recvfd;
#endif #endif
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -14,40 +14,34 @@
constexpr uint32_t MIN_ACCEPTABLE_FRAG_SIZE = 64; constexpr uint32_t MIN_ACCEPTABLE_FRAG_SIZE = 64;
constexpr uint32_t MAX_ACCEPTABLE_FRAG_SIZE = 64000; constexpr uint32_t MAX_ACCEPTABLE_FRAG_SIZE = 64000;
namespace zeek::detail namespace zeek::detail {
{
FragTimer::~FragTimer() FragTimer::~FragTimer() {
{
if ( f ) if ( f )
f->ClearTimer(); f->ClearTimer();
} }
void FragTimer::Dispatch(double t, bool /* is_expire */) void FragTimer::Dispatch(double t, bool /* is_expire */) {
{
if ( f ) if ( f )
f->Expire(t); f->Expire(t);
else else
reporter->InternalWarning("fragment timer dispatched w/o reassembler"); reporter->InternalWarning("fragment timer dispatched w/o reassembler");
} }
FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<IP_Hdr>& ip, FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt,
const u_char* pkt, const FragReassemblerKey& k, double t) const FragReassemblerKey& k, double t)
: Reassembler(0, REASSEM_FRAG) : Reassembler(0, REASSEM_FRAG) {
{
s = arg_s; s = arg_s;
key = k; key = k;
const struct ip* ip4 = ip->IP4_Hdr(); const struct ip* ip4 = ip->IP4_Hdr();
if ( ip4 ) if ( ip4 ) {
{
proto_hdr_len = ip->HdrLen(); proto_hdr_len = ip->HdrLen();
proto_hdr = new u_char[64]; // max IP header + slop proto_hdr = new u_char[64]; // max IP header + slop
// Don't do a structure copy - need to pick up options, too. // Don't do a structure copy - need to pick up options, too.
memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len); memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len);
} }
else else {
{
proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header
proto_hdr = new u_char[proto_hdr_len]; proto_hdr = new u_char[proto_hdr_len];
memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len); memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len);
@ -57,8 +51,7 @@ FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<
frag_size = 0; // flag meaning "not known" frag_size = 0; // flag meaning "not known"
next_proto = ip->NextProto(); next_proto = ip->NextProto();
if ( frag_timeout != 0.0 ) if ( frag_timeout != 0.0 ) {
{
expire_timer = new FragTimer(this, t + frag_timeout); expire_timer = new FragTimer(this, t + frag_timeout);
timer_mgr->Add(expire_timer); timer_mgr->Add(expire_timer);
} }
@ -66,30 +59,25 @@ FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<
expire_timer = nullptr; expire_timer = nullptr;
AddFragment(t, ip, pkt); AddFragment(t, ip, pkt);
} }
FragReassembler::~FragReassembler() FragReassembler::~FragReassembler() {
{
DeleteTimer(); DeleteTimer();
delete[] proto_hdr; delete[] proto_hdr;
} }
void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) {
{
const struct ip* ip4 = ip->IP4_Hdr(); const struct ip* ip4 = ip->IP4_Hdr();
if ( ip4 ) if ( ip4 ) {
{ if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p || ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl )
if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p ||
ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl )
// || ip4->ip_tos != proto_hdr->ip_tos // || ip4->ip_tos != proto_hdr->ip_tos
// don't check TOS, there's at least one stack that actually // don't check TOS, there's at least one stack that actually
// uses different values, and it's hard to see an associated // uses different values, and it's hard to see an associated
// attack. // attack.
s->Weird("fragment_protocol_inconsistency", ip.get()); s->Weird("fragment_protocol_inconsistency", ip.get());
} }
else else {
{
if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len ) if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len )
s->Weird("fragment_protocol_inconsistency", ip.get()); s->Weird("fragment_protocol_inconsistency", ip.get());
// TODO: more detailed unfrag header consistency checks? // TODO: more detailed unfrag header consistency checks?
@ -103,8 +91,7 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
uint32_t len = ip->TotalLen(); uint32_t len = ip->TotalLen();
uint16_t hdr_len = ip->HdrLen(); uint16_t hdr_len = ip->HdrLen();
if ( len < hdr_len ) if ( len < hdr_len ) {
{
s->Weird("fragment_protocol_inconsistency", ip.get()); s->Weird("fragment_protocol_inconsistency", ip.get());
return; return;
} }
@ -115,14 +102,12 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
// Make sure to use the first fragment header's next field. // Make sure to use the first fragment header's next field.
next_proto = ip->NextProto(); next_proto = ip->NextProto();
if ( ! ip->MF() ) if ( ! ip->MF() ) {
{
// Last fragment. // Last fragment.
if ( frag_size == 0 ) if ( frag_size == 0 )
frag_size = upper_seq; frag_size = upper_seq;
else if ( upper_seq != frag_size ) else if ( upper_seq != frag_size ) {
{
s->Weird("fragment_size_inconsistency", ip.get()); s->Weird("fragment_size_inconsistency", ip.get());
if ( upper_seq > frag_size ) if ( upper_seq > frag_size )
@ -136,8 +121,7 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE ) if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE )
s->Weird("excessively_large_fragment", ip.get()); s->Weird("excessively_large_fragment", ip.get());
if ( frag_size && upper_seq > frag_size ) if ( frag_size && upper_seq > frag_size ) {
{
// This can happen if we receive a fragment that's *not* // This can happen if we receive a fragment that's *not*
// the last fragment, but still imputes a size that's // the last fragment, but still imputes a size that's
// larger than the size we derived from a previously-seen // larger than the size we derived from a previously-seen
@ -155,41 +139,35 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
len -= hdr_len; len -= hdr_len;
NewBlock(run_state::network_time, offset, len, pkt); NewBlock(run_state::network_time, offset, len, pkt);
} }
void FragReassembler::Weird(const char* name) const void FragReassembler::Weird(const char* name) const {
{
unsigned int version = ((const ip*)proto_hdr)->ip_v; unsigned int version = ((const ip*)proto_hdr)->ip_v;
if ( version == 4 ) if ( version == 4 ) {
{
IP_Hdr hdr((const ip*)proto_hdr, false); IP_Hdr hdr((const ip*)proto_hdr, false);
s->Weird(name, &hdr); s->Weird(name, &hdr);
} }
else if ( version == 6 ) else if ( version == 6 ) {
{
IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len); IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len);
s->Weird(name, &hdr); s->Weird(name, &hdr);
} }
else else {
{
reporter->InternalWarning("Unexpected IP version in FragReassembler"); reporter->InternalWarning("Unexpected IP version in FragReassembler");
reporter->Weird(name); reporter->Weird(name);
} }
} }
void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n) void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n) {
{
if ( memcmp((const void*)b1, (const void*)b2, n) ) if ( memcmp((const void*)b1, (const void*)b2, n) )
Weird("fragment_inconsistency"); Weird("fragment_inconsistency");
else else
Weird("fragment_overlap"); Weird("fragment_overlap");
} }
void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) {
{
auto it = block_list.Begin(); auto it = block_list.Begin();
if ( it->second.seq > 0 || ! frag_size ) if ( it->second.seq > 0 || ! frag_size )
@ -199,8 +177,7 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
auto next = std::next(it); auto next = std::next(it);
// We might have it all - look for contiguous all the way. // We might have it all - look for contiguous all the way.
while ( next != block_list.End() ) while ( next != block_list.End() ) {
{
if ( it->second.upper != next->second.seq ) if ( it->second.upper != next->second.seq )
break; break;
@ -210,11 +187,9 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
const auto& last = block_list.LastBlock(); const auto& last = block_list.LastBlock();
if ( next != block_list.End() ) if ( next != block_list.End() ) {
{
// We have a hole. // We have a hole.
if ( it->second.upper >= frag_size ) if ( it->second.upper >= frag_size ) {
{
// We're stuck. The point where we stopped is // We're stuck. The point where we stopped is
// contiguous up through the expected end of // contiguous up through the expected end of
// the fragment, but there's more stuff still // the fragment, but there's more stuff still
@ -232,8 +207,7 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
return; return;
} }
else if ( last.upper > frag_size ) else if ( last.upper > frag_size ) {
{
Weird("fragment_size_inconsistency"); Weird("fragment_size_inconsistency");
frag_size = last.upper; frag_size = last.upper;
} }
@ -259,12 +233,10 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
pkt += proto_hdr_len; pkt += proto_hdr_len;
for ( it = block_list.Begin(); it != block_list.End(); ++it ) for ( it = block_list.Begin(); it != block_list.End(); ++it ) {
{
const auto& b = it->second; const auto& b = it->second;
if ( it != block_list.Begin() ) if ( it != block_list.Begin() ) {
{
const auto& prev = std::prev(it)->second; const auto& prev = std::prev(it)->second;
// If we're above a hole, stop. This can happen because // If we're above a hole, stop. This can happen because
@ -274,8 +246,7 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
break; break;
} }
if ( b.upper > n ) if ( b.upper > n ) {
{
reporter->InternalWarning("bad fragment reassembly"); reporter->InternalWarning("bad fragment reassembly");
DeleteTimer(); DeleteTimer();
Expire(run_state::network_time); Expire(run_state::network_time);
@ -290,16 +261,14 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
unsigned int version = ((const struct ip*)pkt_start)->ip_v; unsigned int version = ((const struct ip*)pkt_start)->ip_v;
if ( version == 4 ) if ( version == 4 ) {
{
struct ip* reassem4 = (struct ip*)pkt_start; struct ip* reassem4 = (struct ip*)pkt_start;
reassem4->ip_len = htons(frag_size + proto_hdr_len); reassem4->ip_len = htons(frag_size + proto_hdr_len);
reassembled_pkt = std::make_shared<IP_Hdr>(reassem4, true, true); reassembled_pkt = std::make_shared<IP_Hdr>(reassem4, true, true);
DeleteTimer(); DeleteTimer();
} }
else if ( version == 6 ) else if ( version == 6 ) {
{
struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start; struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start;
reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40); reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40);
const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n); const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n);
@ -307,40 +276,31 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
DeleteTimer(); DeleteTimer();
} }
else else {
{
reporter->InternalWarning("bad IP version in fragment reassembly: %d", version); reporter->InternalWarning("bad IP version in fragment reassembly: %d", version);
delete[] pkt_start; delete[] pkt_start;
} }
} }
void FragReassembler::Expire(double t) void FragReassembler::Expire(double t) {
{
block_list.Clear(); block_list.Clear();
expire_timer->ClearReassembler(); expire_timer->ClearReassembler();
expire_timer = nullptr; // timer manager will delete it expire_timer = nullptr; // timer manager will delete it
fragment_mgr->Remove(this); fragment_mgr->Remove(this);
} }
void FragReassembler::DeleteTimer() void FragReassembler::DeleteTimer() {
{ if ( expire_timer ) {
if ( expire_timer )
{
expire_timer->ClearReassembler(); expire_timer->ClearReassembler();
timer_mgr->Cancel(expire_timer); timer_mgr->Cancel(expire_timer);
expire_timer = nullptr; // timer manager will delete it expire_timer = nullptr; // timer manager will delete it
} }
} }
FragmentManager::~FragmentManager() FragmentManager::~FragmentManager() { Clear(); }
{
Clear();
}
FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) {
const u_char* pkt)
{
uint32_t frag_id = ip->ID(); uint32_t frag_id = ip->ID();
FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id); FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id);
@ -349,8 +309,7 @@ FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<I
if ( it != fragments.end() ) if ( it != fragments.end() )
f = it->second; f = it->second;
if ( ! f ) if ( ! f ) {
{
f = new FragReassembler(session_mgr, ip, pkt, key, t); f = new FragReassembler(session_mgr, ip, pkt, key, t);
fragments[key] = f; fragments[key] = f;
if ( fragments.size() > max_fragments ) if ( fragments.size() > max_fragments )
@ -360,18 +319,16 @@ FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<I
f->AddFragment(t, ip, pkt); f->AddFragment(t, ip, pkt);
return f; return f;
} }
void FragmentManager::Clear() void FragmentManager::Clear() {
{
for ( const auto& entry : fragments ) for ( const auto& entry : fragments )
Unref(entry.second); Unref(entry.second);
fragments.clear(); fragments.clear();
} }
void FragmentManager::Remove(detail::FragReassembler* f) void FragmentManager::Remove(detail::FragReassembler* f) {
{
if ( ! f ) if ( ! f )
return; return;
@ -379,6 +336,6 @@ void FragmentManager::Remove(detail::FragReassembler* f)
reporter->InternalWarning("fragment reassembler not in dict"); reporter->InternalWarning("fragment reassembler not in dict");
Unref(f); Unref(f);
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -10,26 +10,22 @@
#include "zeek/Timer.h" #include "zeek/Timer.h"
#include "zeek/util.h" // for zeek_uint_t #include "zeek/util.h" // for zeek_uint_t
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
namespace session namespace session {
{
class Manager; class Manager;
} }
namespace detail namespace detail {
{
class FragReassembler; class FragReassembler;
class FragTimer; class FragTimer;
using FragReassemblerKey = std::tuple<IPAddr, IPAddr, zeek_uint_t>; using FragReassemblerKey = std::tuple<IPAddr, IPAddr, zeek_uint_t>;
class FragReassembler : public Reassembler class FragReassembler : public Reassembler {
{
public: public:
FragReassembler(session::Manager* s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt, FragReassembler(session::Manager* s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt,
const FragReassemblerKey& k, double t); const FragReassemblerKey& k, double t);
@ -58,10 +54,9 @@ protected:
uint16_t proto_hdr_len; uint16_t proto_hdr_len;
FragTimer* expire_timer; FragTimer* expire_timer;
}; };
class FragTimer final : public Timer class FragTimer final : public Timer {
{
public: public:
FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; } FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; }
~FragTimer() override; ~FragTimer() override;
@ -73,10 +68,9 @@ public:
protected: protected:
FragReassembler* f; FragReassembler* f;
}; };
class FragmentManager class FragmentManager {
{
public: public:
FragmentManager() = default; FragmentManager() = default;
~FragmentManager(); ~FragmentManager();
@ -92,20 +86,19 @@ private:
using FragmentMap = std::map<detail::FragReassemblerKey, detail::FragReassembler*>; using FragmentMap = std::map<detail::FragReassemblerKey, detail::FragReassembler*>;
FragmentMap fragments; FragmentMap fragments;
size_t max_fragments = 0; size_t max_fragments = 0;
}; };
extern FragmentManager* fragment_mgr; extern FragmentManager* fragment_mgr;
class FragReassemblerTracker class FragReassemblerTracker {
{
public: public:
FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) { } FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) {}
~FragReassemblerTracker() { fragment_mgr->Remove(frag_reassembler); } ~FragReassemblerTracker() { fragment_mgr->Remove(frag_reassembler); }
private: private:
FragReassembler* frag_reassembler; FragReassembler* frag_reassembler;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -13,11 +13,9 @@
std::vector<zeek::detail::Frame*> g_frame_stack; std::vector<zeek::detail::Frame*> g_frame_stack;
namespace zeek::detail namespace zeek::detail {
{
Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) {
{
size = arg_size; size = arg_size;
frame = std::make_unique<Element[]>(size); frame = std::make_unique<Element[]>(size);
function = func; function = func;
@ -38,59 +36,49 @@ Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args)
captures = function ? function->GetCapturesFrame() : nullptr; captures = function ? function->GetCapturesFrame() : nullptr;
captures_offset_map = function ? function->GetCapturesOffsetMap() : nullptr; captures_offset_map = function ? function->GetCapturesOffsetMap() : nullptr;
current_offset = 0; current_offset = 0;
} }
void Frame::SetElement(int n, ValPtr v) void Frame::SetElement(int n, ValPtr v) {
{
n += current_offset; n += current_offset;
ASSERT(n >= 0 && n < size); ASSERT(n >= 0 && n < size);
frame[n] = std::move(v); frame[n] = std::move(v);
} }
void Frame::SetElement(const ID* id, ValPtr v) void Frame::SetElement(const ID* id, ValPtr v) {
{ if ( captures ) {
if ( captures )
{
auto cap_off = captures_offset_map->find(id->Name()); auto cap_off = captures_offset_map->find(id->Name());
if ( cap_off != captures_offset_map->end() ) if ( cap_off != captures_offset_map->end() ) {
{
captures->SetElement(cap_off->second, std::move(v)); captures->SetElement(cap_off->second, std::move(v));
return; return;
} }
} }
SetElement(id->Offset(), std::move(v)); SetElement(id->Offset(), std::move(v));
} }
const ValPtr& Frame::GetElementByID(const ID* id) const const ValPtr& Frame::GetElementByID(const ID* id) const {
{ if ( captures ) {
if ( captures )
{
auto cap_off = captures_offset_map->find(id->Name()); auto cap_off = captures_offset_map->find(id->Name());
if ( cap_off != captures_offset_map->end() ) if ( cap_off != captures_offset_map->end() )
return captures->GetElement(cap_off->second); return captures->GetElement(cap_off->second);
} }
return frame[id->Offset() + current_offset]; return frame[id->Offset() + current_offset];
} }
void Frame::Reset(int startIdx) void Frame::Reset(int startIdx) {
{
for ( int i = startIdx + current_offset; i < size; ++i ) for ( int i = startIdx + current_offset; i < size; ++i )
frame[i] = nullptr; frame[i] = nullptr;
} }
void Frame::Describe(ODesc* d) const void Frame::Describe(ODesc* d) const {
{
if ( ! d->IsBinary() ) if ( ! d->IsBinary() )
d->AddSP("frame"); d->AddSP("frame");
if ( ! d->IsReadable() ) if ( ! d->IsReadable() ) {
{
d->Add(size); d->Add(size);
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{
d->Add(frame[i] != nullptr); d->Add(frame[i] != nullptr);
d->SP(); d->SP();
} }
@ -101,10 +89,9 @@ void Frame::Describe(ODesc* d) const
frame[i]->Describe(d); frame[i]->Describe(d);
else if ( d->IsReadable() ) else if ( d->IsReadable() )
d->Add("<nil>"); d->Add("<nil>");
} }
Frame* Frame::Clone() const Frame* Frame::Clone() const {
{
Frame* other = new Frame(size, function, func_args); Frame* other = new Frame(size, function, func_args);
other->call = call; other->call = call;
@ -119,10 +106,9 @@ Frame* Frame::Clone() const
// since those get created fresh when constructing "other". // since those get created fresh when constructing "other".
return other; return other;
} }
Frame* Frame::CloneForTrigger() const Frame* Frame::CloneForTrigger() const {
{
Frame* other = new Frame(0, function, func_args); Frame* other = new Frame(0, function, func_args);
other->call = call; other->call = call;
@ -130,22 +116,19 @@ Frame* Frame::CloneForTrigger() const
other->trigger = trigger; other->trigger = trigger;
return other; return other;
} }
static bool val_is_func(const ValPtr& v, ScriptFunc* func) static bool val_is_func(const ValPtr& v, ScriptFunc* func) {
{
if ( v->GetType()->Tag() != TYPE_FUNC ) if ( v->GetType()->Tag() != TYPE_FUNC )
return false; return false;
return v->AsFunc() == func; return v->AsFunc() == func;
} }
broker::expected<broker::data> Frame::Serialize() broker::expected<broker::data> Frame::Serialize() {
{
broker::vector body; broker::vector body;
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{
const auto& val = frame[i]; const auto& val = frame[i];
auto expected = Broker::detail::val_to_data(val.get()); auto expected = Broker::detail::val_to_data(val.get());
if ( ! expected ) if ( ! expected )
@ -160,10 +143,9 @@ broker::expected<broker::data> Frame::Serialize()
rval.emplace_back(std::move(body)); rval.emplace_back(std::move(body));
return {std::move(rval)}; return {std::move(rval)};
} }
std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data) std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data) {
{
if ( data.size() == 0 ) if ( data.size() == 0 )
return std::make_pair(true, nullptr); return std::make_pair(true, nullptr);
@ -176,8 +158,7 @@ std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data)
int frame_size = body.size(); int frame_size = body.size();
auto rf = make_intrusive<Frame>(frame_size, nullptr, nullptr); auto rf = make_intrusive<Frame>(frame_size, nullptr, nullptr);
for ( int i = 0; i < frame_size; ++i ) for ( int i = 0; i < frame_size; ++i ) {
{
auto has_vec = broker::get_if<broker::vector>(body[i]); auto has_vec = broker::get_if<broker::vector>(body[i]);
if ( ! has_vec ) if ( ! has_vec )
continue; continue;
@ -201,23 +182,16 @@ std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data)
} }
return std::make_pair(true, std::move(rf)); return std::make_pair(true, std::move(rf));
} }
const detail::Location* Frame::GetCallLocation() const const detail::Location* Frame::GetCallLocation() const {
{
// This is currently trivial, but we keep it as an explicit // This is currently trivial, but we keep it as an explicit
// method because it can provide flexibility for compiled code. // method because it can provide flexibility for compiled code.
return call->GetLocationInfo(); return call->GetLocationInfo();
} }
void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) { trigger = std::move(arg_trigger); }
{
trigger = std::move(arg_trigger);
}
void Frame::ClearTrigger() void Frame::ClearTrigger() { trigger = nullptr; }
{
trigger = nullptr;
}
} } // namespace zeek::detail

View file

@ -17,31 +17,27 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" // for typedef val_list #include "zeek/ZeekList.h" // for typedef val_list
namespace zeek namespace zeek {
{
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
namespace detail namespace detail {
{
class CallExpr; class CallExpr;
class ScriptFunc; class ScriptFunc;
using IDPtr = IntrusivePtr<ID>; using IDPtr = IntrusivePtr<ID>;
namespace trigger namespace trigger {
{
class Trigger; class Trigger;
using TriggerPtr = IntrusivePtr<Trigger>; using TriggerPtr = IntrusivePtr<Trigger>;
} } // namespace trigger
class Frame; class Frame;
using FramePtr = IntrusivePtr<Frame>; using FramePtr = IntrusivePtr<Frame>;
class Frame : public Obj class Frame : public Obj {
{
public: public:
/** /**
* Constructs a new frame belonging to *func* with *fn_args* * Constructs a new frame belonging to *func* with *fn_args*
@ -64,8 +60,7 @@ public:
* @param n the index to get. * @param n the index to get.
* @return the value at index *n* of the underlying array. * @return the value at index *n* of the underlying array.
*/ */
const ValPtr& GetElement(int n) const const ValPtr& GetElement(int n) const {
{
// Note: technically this may want to adjust by current_offset, but // Note: technically this may want to adjust by current_offset, but
// in practice, this method is never called from anywhere other than // in practice, this method is never called from anywhere other than
// function call invocation, where current_offset should be zero. // function call invocation, where current_offset should be zero.
@ -193,8 +188,7 @@ public:
void ClearTrigger(); void ClearTrigger();
trigger::Trigger* GetTrigger() const { return trigger.get(); } trigger::Trigger* GetTrigger() const { return trigger.get(); }
void SetCall(const CallExpr* arg_call) void SetCall(const CallExpr* arg_call) {
{
call = arg_call; call = arg_call;
SetTriggerAssoc((void*)call); SetTriggerAssoc((void*)call);
} }
@ -257,10 +251,10 @@ private:
trigger::TriggerPtr trigger; trigger::TriggerPtr trigger;
const CallExpr* call = nullptr; const CallExpr* call = nullptr;
const void* assoc = nullptr; const void* assoc = nullptr;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek
/** /**
* If we stopped using this and instead just made a struct of the information * If we stopped using this and instead just made a struct of the information

File diff suppressed because it is too large Load diff

View file

@ -18,21 +18,19 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
namespace broker namespace broker {
{
class data; class data;
using vector = std::vector<data>; using vector = std::vector<data>;
template <class> class expected; template<class>
} class expected;
} // namespace broker
namespace zeek namespace zeek {
{
class Val; class Val;
class FuncType; class FuncType;
namespace detail namespace detail {
{
class Scope; class Scope;
class Stmt; class Stmt;
@ -46,7 +44,7 @@ using StmtPtr = IntrusivePtr<Stmt>;
class ScriptFunc; class ScriptFunc;
class FunctionIngredients; class FunctionIngredients;
} // namespace detail } // namespace detail
class EventGroup; class EventGroup;
using EventGroupPtr = std::shared_ptr<EventGroup>; using EventGroupPtr = std::shared_ptr<EventGroup>;
@ -54,24 +52,18 @@ using EventGroupPtr = std::shared_ptr<EventGroup>;
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
class Func : public Obj class Func : public Obj {
{
public: public:
static inline const FuncPtr nil; static inline const FuncPtr nil;
enum Kind enum Kind { SCRIPT_FUNC, BUILTIN_FUNC };
{
SCRIPT_FUNC,
BUILTIN_FUNC
};
explicit Func(Kind arg_kind) : kind(arg_kind) { } explicit Func(Kind arg_kind) : kind(arg_kind) {}
virtual bool IsPure() const = 0; virtual bool IsPure() const = 0;
FunctionFlavor Flavor() const { return GetType()->Flavor(); } FunctionFlavor Flavor() const { return GetType()->Flavor(); }
struct Body struct Body {
{
detail::StmtPtr stmts; detail::StmtPtr stmts;
int priority; int priority;
std::set<EventGroupPtr> groups; std::set<EventGroupPtr> groups;
@ -79,10 +71,7 @@ public:
// The disabled field is updated from EventGroup instances. // The disabled field is updated from EventGroup instances.
bool disabled = false; bool disabled = false;
bool operator<(const Body& other) const bool operator<(const Body& other) const { return priority > other.priority; } // reverse sort
{
return priority > other.priority;
} // reverse sort
}; };
const std::vector<Body>& GetBodies() const { return bodies; } const std::vector<Body>& GetBodies() const { return bodies; }
@ -106,11 +95,9 @@ public:
/** /**
* A version of Invoke() taking a variable number of individual arguments. * A version of Invoke() taking a variable number of individual arguments.
*/ */
template <class... Args> template<class... Args>
std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>, std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>, ValPtr> Invoke(
ValPtr> Args&&... args) const {
Invoke(Args&&... args) const
{
auto zargs = zeek::Args{std::forward<Args>(args)...}; auto zargs = zeek::Args{std::forward<Args>(args)...};
return Invoke(&zargs); return Invoke(&zargs);
} }
@ -121,11 +108,10 @@ public:
// as is a non-default second parameter to the first method, which // as is a non-default second parameter to the first method, which
// overrides the function body in "ingr". // overrides the function body in "ingr".
void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr); void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr);
virtual void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, virtual void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
size_t new_frame_size, int priority, int priority, const std::set<EventGroupPtr>& groups);
const std::set<EventGroupPtr>& groups); void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, int priority = 0);
size_t new_frame_size, int priority = 0);
void AddBody(detail::StmtPtr new_body, size_t new_frame_size); void AddBody(detail::StmtPtr new_body, size_t new_frame_size);
virtual void SetScope(detail::ScopePtr newscope); virtual void SetScope(detail::ScopePtr newscope);
@ -166,19 +152,16 @@ private:
// expose accessors in the zeek:: public interface. // expose accessors in the zeek:: public interface.
friend class EventGroup; friend class EventGroup;
bool has_enabled_bodies = true; bool has_enabled_bodies = true;
}; };
namespace detail namespace detail {
{
class ScriptFunc : public Func class ScriptFunc : public Func {
{
public: public:
ScriptFunc(const IDPtr& id); ScriptFunc(const IDPtr& id);
// For compiled scripts. // For compiled scripts.
ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies, ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies, std::vector<int> priorities);
std::vector<int> priorities);
~ScriptFunc() override; ~ScriptFunc() override;
@ -220,8 +203,7 @@ public:
* *
* @return internal vector of ZVal's kept for persisting captures * @return internal vector of ZVal's kept for persisting captures
*/ */
auto& GetCapturesVec() const auto& GetCapturesVec() const {
{
ASSERT(captures_vec); ASSERT(captures_vec);
return *captures_vec; return *captures_vec;
} }
@ -252,9 +234,8 @@ public:
using Func::AddBody; using Func::AddBody;
void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
size_t new_frame_size, int priority, int priority, const std::set<EventGroupPtr>& groups) override;
const std::set<EventGroupPtr>& groups) override;
/** /**
* Replaces the given current instance of a function body with * Replaces the given current instance of a function body with
@ -289,7 +270,7 @@ public:
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
protected: protected:
ScriptFunc() : Func(SCRIPT_FUNC) { } ScriptFunc() : Func(SCRIPT_FUNC) {}
StmtPtr AddInits(StmtPtr body, const std::vector<IDPtr>& inits); StmtPtr AddInits(StmtPtr body, const std::vector<IDPtr>& inits);
@ -328,12 +309,11 @@ private:
// ... and its priority. // ... and its priority.
int current_priority = 0; int current_priority = 0;
}; };
using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args); using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args);
class BuiltinFunc final : public Func class BuiltinFunc final : public Func {
{
public: public:
BuiltinFunc(built_in_func func, const char* name, bool is_pure); BuiltinFunc(built_in_func func, const char* name, bool is_pure);
~BuiltinFunc() override = default; ~BuiltinFunc() override = default;
@ -345,28 +325,25 @@ public:
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
protected: protected:
BuiltinFunc() BuiltinFunc() {
{
func = nullptr; func = nullptr;
is_pure = 0; is_pure = 0;
} }
built_in_func func; built_in_func func;
bool is_pure; bool is_pure;
}; };
extern bool check_built_in_call(BuiltinFunc* f, CallExpr* call); extern bool check_built_in_call(BuiltinFunc* f, CallExpr* call);
struct CallInfo struct CallInfo {
{
const CallExpr* call; const CallExpr* call;
const Func* func; const Func* func;
const zeek::Args& args; const zeek::Args& args;
}; };
// Class that collects all the specifics defining a Func. // Class that collects all the specifics defining a Func.
class FunctionIngredients class FunctionIngredients {
{
public: public:
// Gathers all of the information from a scope and a function body needed // Gathers all of the information from a scope and a function body needed
// to build a function. // to build a function.
@ -397,7 +374,7 @@ private:
int priority = 0; int priority = 0;
ScopePtr scope; ScopePtr scope;
std::set<EventGroupPtr> groups; std::set<EventGroupPtr> groups;
}; };
using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>; using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>;
@ -427,19 +404,18 @@ extern bool did_builtin_init;
extern std::vector<void (*)()> bif_initializers; extern std::vector<void (*)()> bif_initializers;
extern void init_primary_bifs(); extern void init_primary_bifs();
inline void run_bif_initializers() inline void run_bif_initializers() {
{
for ( const auto& bi : bif_initializers ) for ( const auto& bi : bif_initializers )
bi(); bi();
bif_initializers = {}; bif_initializers = {};
} }
extern void emit_builtin_exception(const char* msg); extern void emit_builtin_exception(const char* msg);
extern void emit_builtin_exception(const char* msg, const ValPtr& arg); extern void emit_builtin_exception(const char* msg, const ValPtr& arg);
extern void emit_builtin_exception(const char* msg, Obj* arg); extern void emit_builtin_exception(const char* msg, Obj* arg);
} // namespace detail } // namespace detail
extern std::string render_call_stack(); extern std::string render_call_stack();
@ -448,4 +424,4 @@ extern void emit_builtin_error(const char* msg);
extern void emit_builtin_error(const char* msg, const ValPtr&); extern void emit_builtin_error(const char* msg, const ValPtr&);
extern void emit_builtin_error(const char* msg, Obj* arg); extern void emit_builtin_error(const char* msg, Obj* arg);
} // namespace zeek } // namespace zeek

View file

@ -18,8 +18,7 @@
#include "const.bif.netvar_h" #include "const.bif.netvar_h"
namespace zeek::detail namespace zeek::detail {
{
alignas(32) uint64_t KeyedHash::shared_highwayhash_key[4]; alignas(32) uint64_t KeyedHash::shared_highwayhash_key[4];
alignas(32) uint64_t KeyedHash::cluster_highwayhash_key[4]; alignas(32) uint64_t KeyedHash::cluster_highwayhash_key[4];
@ -27,17 +26,12 @@ alignas(16) unsigned long long KeyedHash::shared_siphash_key[2];
// we use the following lines to not pull in the highwayhash headers in Hash.h - but to check the // we use the following lines to not pull in the highwayhash headers in Hash.h - but to check the
// types did not change underneath us. // types did not change underneath us.
static_assert(std::is_same_v<hash64_t, highwayhash::HHResult64>, static_assert(std::is_same_v<hash64_t, highwayhash::HHResult64>, "Highwayhash return values must match hash_x_t");
"Highwayhash return values must match hash_x_t"); static_assert(std::is_same_v<hash128_t, highwayhash::HHResult128>, "Highwayhash return values must match hash_x_t");
static_assert(std::is_same_v<hash128_t, highwayhash::HHResult128>, static_assert(std::is_same_v<hash256_t, highwayhash::HHResult256>, "Highwayhash return values must match hash_x_t");
"Highwayhash return values must match hash_x_t");
static_assert(std::is_same_v<hash256_t, highwayhash::HHResult256>,
"Highwayhash return values must match hash_x_t");
void KeyedHash::InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed_data) void KeyedHash::InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed_data) {
{ static_assert(std::is_same_v<decltype(KeyedHash::shared_siphash_key), highwayhash::SipHashState::Key>,
static_assert(
std::is_same_v<decltype(KeyedHash::shared_siphash_key), highwayhash::SipHashState::Key>,
"Highwayhash Key is not unsigned long long[2]"); "Highwayhash Key is not unsigned long long[2]");
static_assert(std::is_same_v<decltype(KeyedHash::shared_highwayhash_key), highwayhash::HHKey>, static_assert(std::is_same_v<decltype(KeyedHash::shared_highwayhash_key), highwayhash::HHKey>,
"Highwayhash HHKey is not uint64_t[4]"); "Highwayhash HHKey is not uint64_t[4]");
@ -57,137 +51,101 @@ void KeyedHash::InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed
memcpy(shared_siphash_key, reinterpret_cast<const char*>(seed_data.data()) + 64, 16); memcpy(shared_siphash_key, reinterpret_cast<const char*>(seed_data.data()) + 64, 16);
seeds_initialized = true; seeds_initialized = true;
} }
void KeyedHash::InitOptions() void KeyedHash::InitOptions() {
{
calculate_digest(Hash_SHA256, BifConst::digest_salt->Bytes(), BifConst::digest_salt->Len(), calculate_digest(Hash_SHA256, BifConst::digest_salt->Bytes(), BifConst::digest_salt->Len(),
reinterpret_cast<unsigned char*>(cluster_highwayhash_key)); reinterpret_cast<unsigned char*>(cluster_highwayhash_key));
} }
hash64_t KeyedHash::Hash64(const void* bytes, uint64_t size) hash64_t KeyedHash::Hash64(const void* bytes, uint64_t size) {
{
return highwayhash::SipHash(shared_siphash_key, static_cast<const char*>(bytes), size); return highwayhash::SipHash(shared_siphash_key, static_cast<const char*>(bytes), size);
} }
void KeyedHash::Hash128(const void* bytes, uint64_t size, hash128_t* result) void KeyedHash::Hash128(const void* bytes, uint64_t size, hash128_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(shared_highwayhash_key, static_cast<const char*>(bytes),
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( size, result);
shared_highwayhash_key, static_cast<const char*>(bytes), size, result); }
}
void KeyedHash::Hash256(const void* bytes, uint64_t size, hash256_t* result) void KeyedHash::Hash256(const void* bytes, uint64_t size, hash256_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(shared_highwayhash_key, static_cast<const char*>(bytes),
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( size, result);
shared_highwayhash_key, static_cast<const char*>(bytes), size, result); }
}
hash64_t KeyedHash::StaticHash64(const void* bytes, uint64_t size) hash64_t KeyedHash::StaticHash64(const void* bytes, uint64_t size) {
{
hash64_t result = 0; hash64_t result = 0;
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(cluster_highwayhash_key,
cluster_highwayhash_key, static_cast<const char*>(bytes), size, &result); static_cast<const char*>(bytes), size, &result);
return result; return result;
} }
void KeyedHash::StaticHash128(const void* bytes, uint64_t size, hash128_t* result) void KeyedHash::StaticHash128(const void* bytes, uint64_t size, hash128_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(cluster_highwayhash_key,
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( static_cast<const char*>(bytes), size, result);
cluster_highwayhash_key, static_cast<const char*>(bytes), size, result); }
}
void KeyedHash::StaticHash256(const void* bytes, uint64_t size, hash256_t* result) void KeyedHash::StaticHash256(const void* bytes, uint64_t size, hash256_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(cluster_highwayhash_key,
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( static_cast<const char*>(bytes), size, result);
cluster_highwayhash_key, static_cast<const char*>(bytes), size, result); }
}
void init_hash_function() void init_hash_function() {
{
// Make sure we have already called init_random_seed(). // Make sure we have already called init_random_seed().
if ( ! KeyedHash::IsInitialized() ) if ( ! KeyedHash::IsInitialized() )
reporter->InternalError("Zeek's hash functions aren't fully initialized"); reporter->InternalError("Zeek's hash functions aren't fully initialized");
} }
HashKey::HashKey(bool b) HashKey::HashKey(bool b) { Set(b); }
{
Set(b);
}
HashKey::HashKey(int i) HashKey::HashKey(int i) { Set(i); }
{
Set(i);
}
HashKey::HashKey(zeek_int_t bi) HashKey::HashKey(zeek_int_t bi) { Set(bi); }
{
Set(bi);
}
HashKey::HashKey(zeek_uint_t bu) HashKey::HashKey(zeek_uint_t bu) { Set(bu); }
{
Set(bu);
}
HashKey::HashKey(uint32_t u) HashKey::HashKey(uint32_t u) { Set(u); }
{
Set(u);
}
HashKey::HashKey(const uint32_t u[], size_t n) HashKey::HashKey(const uint32_t u[], size_t n) {
{
size = write_size = n * sizeof(u[0]); size = write_size = n * sizeof(u[0]);
key = (char*)u; key = (char*)u;
} }
HashKey::HashKey(double d) HashKey::HashKey(double d) { Set(d); }
{
Set(d);
}
HashKey::HashKey(const void* p) HashKey::HashKey(const void* p) { Set(p); }
{
Set(p);
}
HashKey::HashKey(const char* s) HashKey::HashKey(const char* s) {
{
size = write_size = strlen(s); // note - skip final \0 size = write_size = strlen(s); // note - skip final \0
key = (char*)s; key = (char*)s;
} }
HashKey::HashKey(const String* s) HashKey::HashKey(const String* s) {
{
size = write_size = s->Len(); size = write_size = s->Len();
key = (char*)s->Bytes(); key = (char*)s->Bytes();
} }
HashKey::HashKey(const void* bytes, size_t arg_size) HashKey::HashKey(const void* bytes, size_t arg_size) {
{
size = write_size = arg_size; size = write_size = arg_size;
key = CopyKey((char*)bytes, size); key = CopyKey((char*)bytes, size);
is_our_dynamic = true; is_our_dynamic = true;
} }
HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash) HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash) {
{
size = write_size = arg_size; size = write_size = arg_size;
hash = arg_hash; hash = arg_hash;
key = CopyKey((char*)arg_key, size); key = CopyKey((char*)arg_key, size);
is_our_dynamic = true; is_our_dynamic = true;
} }
HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash, bool /* dont_copy */) HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash, bool /* dont_copy */) {
{
size = write_size = arg_size; size = write_size = arg_size;
hash = arg_hash; hash = arg_hash;
key = (char*)arg_key; key = (char*)arg_key;
} }
HashKey::HashKey(const HashKey& other) : HashKey(other.key, other.size, other.hash) { } HashKey::HashKey(const HashKey& other) : HashKey(other.key, other.size, other.hash) {}
HashKey::HashKey(HashKey&& other) noexcept HashKey::HashKey(HashKey&& other) noexcept {
{
hash = other.hash; hash = other.hash;
size = other.size; size = other.size;
write_size = other.write_size; write_size = other.write_size;
@ -199,55 +157,46 @@ HashKey::HashKey(HashKey&& other) noexcept
other.size = 0; other.size = 0;
other.is_our_dynamic = false; other.is_our_dynamic = false;
other.key = nullptr; other.key = nullptr;
} }
HashKey::~HashKey() HashKey::~HashKey() {
{
if ( is_our_dynamic ) if ( is_our_dynamic )
delete[] reinterpret_cast<char*>(key); delete[] reinterpret_cast<char*>(key);
} }
hash_t HashKey::Hash() const hash_t HashKey::Hash() const {
{
if ( hash == 0 ) if ( hash == 0 )
hash = HashBytes(key, size); hash = HashBytes(key, size);
#ifdef DEBUG #ifdef DEBUG
if ( zeek::detail::debug_logger.IsEnabled(DBG_HASHKEY) ) if ( zeek::detail::debug_logger.IsEnabled(DBG_HASHKEY) ) {
{
ODesc d; ODesc d;
Describe(&d); Describe(&d);
DBG_LOG(DBG_HASHKEY, "HashKey %p %s", this, d.Description()); DBG_LOG(DBG_HASHKEY, "HashKey %p %s", this, d.Description());
} }
#endif #endif
return hash; return hash;
} }
void* HashKey::TakeKey() void* HashKey::TakeKey() {
{ if ( is_our_dynamic ) {
if ( is_our_dynamic )
{
is_our_dynamic = false; is_our_dynamic = false;
return key; return key;
} }
else else
return CopyKey(key, size); return CopyKey(key, size);
} }
void HashKey::Describe(ODesc* d) const void HashKey::Describe(ODesc* d) const {
{
char buf[64]; char buf[64];
snprintf(buf, 16, "%0" PRIx64, hash); snprintf(buf, 16, "%0" PRIx64, hash);
d->Add(buf); d->Add(buf);
d->SP(); d->SP();
if ( size > 0 ) if ( size > 0 ) {
{
d->Add(IsAllocated() ? "(" : "["); d->Add(IsAllocated() ? "(" : "[");
for ( size_t i = 0; i < size; i++ ) for ( size_t i = 0; i < size; i++ ) {
{ if ( i > 0 ) {
if ( i > 0 )
{
d->SP(); d->SP();
// Extra spacing every 8 bytes, for readability. // Extra spacing every 8 bytes, for readability.
if ( i % 8 == 0 ) if ( i % 8 == 0 )
@ -255,8 +204,7 @@ void HashKey::Describe(ODesc* d) const
} }
// Don't display unwritten content, only say how much there is. // Don't display unwritten content, only say how much there is.
if ( i > write_size ) if ( i > write_size ) {
{
d->Add("<+"); d->Add("<+");
d->Add(static_cast<uint64_t>(size - write_size - 1)); d->Add(static_cast<uint64_t>(size - write_size - 1));
d->Add(" of "); d->Add(" of ");
@ -271,84 +219,70 @@ void HashKey::Describe(ODesc* d) const
d->Add(IsAllocated() ? ")" : "]"); d->Add(IsAllocated() ? ")" : "]");
} }
} }
char* HashKey::CopyKey(const char* k, size_t s) const char* HashKey::CopyKey(const char* k, size_t s) const {
{
char* k_copy = new char[s]; // s == 0 is okay, returns non-nil char* k_copy = new char[s]; // s == 0 is okay, returns non-nil
memcpy(k_copy, k, s); memcpy(k_copy, k, s);
return k_copy; return k_copy;
} }
hash_t HashKey::HashBytes(const void* bytes, size_t size) hash_t HashKey::HashBytes(const void* bytes, size_t size) { return KeyedHash::Hash64(bytes, size); }
{
return KeyedHash::Hash64(bytes, size);
}
void HashKey::Set(bool b) void HashKey::Set(bool b) {
{
key_u.b = b; key_u.b = b;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(b); size = write_size = sizeof(b);
} }
void HashKey::Set(int i) void HashKey::Set(int i) {
{
key_u.i = i; key_u.i = i;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(i); size = write_size = sizeof(i);
} }
void HashKey::Set(zeek_int_t bi) void HashKey::Set(zeek_int_t bi) {
{
key_u.bi = bi; key_u.bi = bi;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(bi); size = write_size = sizeof(bi);
} }
void HashKey::Set(zeek_uint_t bu) void HashKey::Set(zeek_uint_t bu) {
{
key_u.bi = zeek_int_t(bu); key_u.bi = zeek_int_t(bu);
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(bu); size = write_size = sizeof(bu);
} }
void HashKey::Set(uint32_t u) void HashKey::Set(uint32_t u) {
{
key_u.u32 = u; key_u.u32 = u;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(u); size = write_size = sizeof(u);
} }
void HashKey::Set(double d) void HashKey::Set(double d) {
{
key_u.d = d; key_u.d = d;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(d); size = write_size = sizeof(d);
} }
void HashKey::Set(const void* p) void HashKey::Set(const void* p) {
{
key_u.p = p; key_u.p = p;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(p); size = write_size = sizeof(p);
} }
void HashKey::Reserve(const char* tag, size_t addl_size, size_t alignment) void HashKey::Reserve(const char* tag, size_t addl_size, size_t alignment) {
{
ASSERT(! IsAllocated()); ASSERT(! IsAllocated());
size_t s0 = size; size_t s0 = size;
size_t s1 = util::memory_size_align(size, alignment); size_t s1 = util::memory_size_align(size, alignment);
size = s1 + addl_size; size = s1 + addl_size;
DBG_LOG(DBG_HASHKEY, "HashKey %p reserving %lu/%lu: %lu -> %lu -> %lu [%s]", this, addl_size, DBG_LOG(DBG_HASHKEY, "HashKey %p reserving %lu/%lu: %lu -> %lu -> %lu [%s]", this, addl_size, alignment, s0, s1,
alignment, s0, s1, size, tag); size, tag);
} }
void HashKey::Allocate() void HashKey::Allocate() {
{ if ( key != nullptr && key != reinterpret_cast<char*>(&key_u) ) {
if ( key != nullptr && key != reinterpret_cast<char*>(&key_u) )
{
reporter->InternalWarning("usage error in HashKey::Allocate(): already allocated"); reporter->InternalWarning("usage error in HashKey::Allocate(): already allocated");
return; return;
} }
@ -358,70 +292,56 @@ void HashKey::Allocate()
read_size = 0; read_size = 0;
write_size = 0; write_size = 0;
} }
void HashKey::Write(const char* tag, bool b) void HashKey::Write(const char* tag, bool b) { Write(tag, &b, sizeof(b), 0); }
{
Write(tag, &b, sizeof(b), 0);
}
void HashKey::Write(const char* tag, int i, bool align) void HashKey::Write(const char* tag, int i, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(i); Set(i);
return; return;
} }
Write(tag, &i, sizeof(i), align ? sizeof(i) : 0); Write(tag, &i, sizeof(i), align ? sizeof(i) : 0);
} }
void HashKey::Write(const char* tag, zeek_int_t bi, bool align) void HashKey::Write(const char* tag, zeek_int_t bi, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(bi); Set(bi);
return; return;
} }
Write(tag, &bi, sizeof(bi), align ? sizeof(bi) : 0); Write(tag, &bi, sizeof(bi), align ? sizeof(bi) : 0);
} }
void HashKey::Write(const char* tag, zeek_uint_t bu, bool align) void HashKey::Write(const char* tag, zeek_uint_t bu, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(bu); Set(bu);
return; return;
} }
Write(tag, &bu, sizeof(bu), align ? sizeof(bu) : 0); Write(tag, &bu, sizeof(bu), align ? sizeof(bu) : 0);
} }
void HashKey::Write(const char* tag, uint32_t u, bool align) void HashKey::Write(const char* tag, uint32_t u, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(u); Set(u);
return; return;
} }
Write(tag, &u, sizeof(u), align ? sizeof(u) : 0); Write(tag, &u, sizeof(u), align ? sizeof(u) : 0);
} }
void HashKey::Write(const char* tag, double d, bool align) void HashKey::Write(const char* tag, double d, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(d); Set(d);
return; return;
} }
Write(tag, &d, sizeof(d), align ? sizeof(d) : 0); Write(tag, &d, sizeof(d), align ? sizeof(d) : 0);
} }
void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignment) void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignment) {
{
size_t s0 = write_size; size_t s0 = write_size;
AlignWrite(alignment); AlignWrite(alignment);
size_t s1 = write_size; size_t s1 = write_size;
@ -430,21 +350,18 @@ void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignme
memcpy(key + write_size, bytes, n); memcpy(key + write_size, bytes, n);
write_size += n; write_size += n;
DBG_LOG(DBG_HASHKEY, "HashKey %p writing %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, DBG_LOG(DBG_HASHKEY, "HashKey %p writing %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, s0, s1, write_size,
s0, s1, write_size, tag); tag);
} }
void HashKey::SkipWrite(const char* tag, size_t n) void HashKey::SkipWrite(const char* tag, size_t n) {
{ DBG_LOG(DBG_HASHKEY, "HashKey %p skip-writing %lu: %lu -> %lu [%s]", this, n, write_size, write_size + n, tag);
DBG_LOG(DBG_HASHKEY, "HashKey %p skip-writing %lu: %lu -> %lu [%s]", this, n, write_size,
write_size + n, tag);
EnsureWriteSpace(n); EnsureWriteSpace(n);
write_size += n; write_size += n;
} }
void HashKey::AlignWrite(size_t alignment) void HashKey::AlignWrite(size_t alignment) {
{
ASSERT(IsAllocated()); ASSERT(IsAllocated());
if ( alignment == 0 ) if ( alignment == 0 )
@ -455,16 +372,16 @@ void HashKey::AlignWrite(size_t alignment)
write_size = util::memory_size_align(write_size, alignment); write_size = util::memory_size_align(write_size, alignment);
if ( write_size > size ) if ( write_size > size )
reporter->InternalError("buffer overflow in HashKey::AlignWrite(): " reporter->InternalError(
"buffer overflow in HashKey::AlignWrite(): "
"after alignment, %lu bytes used of %lu allocated", "after alignment, %lu bytes used of %lu allocated",
write_size, size); write_size, size);
while ( old_size < write_size ) while ( old_size < write_size )
key[old_size++] = '\0'; key[old_size++] = '\0';
} }
void HashKey::AlignRead(size_t alignment) const void HashKey::AlignRead(size_t alignment) const {
{
ASSERT(IsAllocated()); ASSERT(IsAllocated());
if ( alignment == 0 ) if ( alignment == 0 )
@ -475,43 +392,29 @@ void HashKey::AlignRead(size_t alignment) const
read_size = util::memory_size_align(read_size, alignment); read_size = util::memory_size_align(read_size, alignment);
if ( read_size > size ) if ( read_size > size )
reporter->InternalError("buffer overflow in HashKey::AlignRead(): " reporter->InternalError(
"buffer overflow in HashKey::AlignRead(): "
"after alignment, %lu bytes used of %lu allocated", "after alignment, %lu bytes used of %lu allocated",
read_size, size); read_size, size);
} }
void HashKey::Read(const char* tag, bool& b) const void HashKey::Read(const char* tag, bool& b) const { Read(tag, &b, sizeof(b), 0); }
{
Read(tag, &b, sizeof(b), 0);
}
void HashKey::Read(const char* tag, int& i, bool align) const void HashKey::Read(const char* tag, int& i, bool align) const { Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); }
{
void HashKey::Read(const char* tag, zeek_int_t& i, bool align) const {
Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); Read(tag, &i, sizeof(i), align ? sizeof(i) : 0);
} }
void HashKey::Read(const char* tag, zeek_int_t& i, bool align) const void HashKey::Read(const char* tag, zeek_uint_t& u, bool align) const {
{
Read(tag, &i, sizeof(i), align ? sizeof(i) : 0);
}
void HashKey::Read(const char* tag, zeek_uint_t& u, bool align) const
{
Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); Read(tag, &u, sizeof(u), align ? sizeof(u) : 0);
} }
void HashKey::Read(const char* tag, uint32_t& u, bool align) const void HashKey::Read(const char* tag, uint32_t& u, bool align) const { Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); }
{
Read(tag, &u, sizeof(u), align ? sizeof(u) : 0);
}
void HashKey::Read(const char* tag, double& d, bool align) const void HashKey::Read(const char* tag, double& d, bool align) const { Read(tag, &d, sizeof(d), align ? sizeof(d) : 0); }
{
Read(tag, &d, sizeof(d), align ? sizeof(d) : 0);
}
void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const {
{
size_t s0 = read_size; size_t s0 = read_size;
AlignRead(alignment); AlignRead(alignment);
size_t s1 = read_size; size_t s1 = read_size;
@ -522,73 +425,69 @@ void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const
// in memcpy even if the size is 0. // in memcpy even if the size is 0.
ASSERT(out != nullptr || (out == nullptr && n == 0)); ASSERT(out != nullptr || (out == nullptr && n == 0));
if ( n > 0 ) if ( n > 0 ) {
{
memcpy(out, key + read_size, n); memcpy(out, key + read_size, n);
read_size += n; read_size += n;
} }
DBG_LOG(DBG_HASHKEY, "HashKey %p reading %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, DBG_LOG(DBG_HASHKEY, "HashKey %p reading %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, s0, s1, read_size,
s0, s1, read_size, tag); tag);
} }
void HashKey::SkipRead(const char* tag, size_t n) const void HashKey::SkipRead(const char* tag, size_t n) const {
{ DBG_LOG(DBG_HASHKEY, "HashKey %p skip-reading %lu: %lu -> %lu [%s]", this, n, read_size, read_size + n, tag);
DBG_LOG(DBG_HASHKEY, "HashKey %p skip-reading %lu: %lu -> %lu [%s]", this, n, read_size,
read_size + n, tag);
EnsureReadSpace(n); EnsureReadSpace(n);
read_size += n; read_size += n;
} }
void HashKey::EnsureWriteSpace(size_t n) const void HashKey::EnsureWriteSpace(size_t n) const {
{
if ( n == 0 ) if ( n == 0 )
return; return;
if ( ! IsAllocated() ) if ( ! IsAllocated() )
reporter->InternalError("usage error in HashKey::EnsureWriteSpace(): " reporter->InternalError(
"usage error in HashKey::EnsureWriteSpace(): "
"size-checking unreserved buffer"); "size-checking unreserved buffer");
if ( write_size + n > size ) if ( write_size + n > size )
reporter->InternalError("buffer overflow in HashKey::Write(): writing %lu " reporter->InternalError(
"buffer overflow in HashKey::Write(): writing %lu "
"bytes with %lu remaining", "bytes with %lu remaining",
n, size - write_size); n, size - write_size);
} }
void HashKey::EnsureReadSpace(size_t n) const void HashKey::EnsureReadSpace(size_t n) const {
{
if ( n == 0 ) if ( n == 0 )
return; return;
if ( ! IsAllocated() ) if ( ! IsAllocated() )
reporter->InternalError("usage error in HashKey::EnsureReadSpace(): " reporter->InternalError(
"usage error in HashKey::EnsureReadSpace(): "
"size-checking unreserved buffer"); "size-checking unreserved buffer");
if ( read_size + n > size ) if ( read_size + n > size )
reporter->InternalError("buffer overflow in HashKey::EnsureReadSpace(): reading %lu " reporter->InternalError(
"buffer overflow in HashKey::EnsureReadSpace(): reading %lu "
"bytes with %lu remaining", "bytes with %lu remaining",
n, size - read_size); n, size - read_size);
} }
bool HashKey::operator==(const HashKey& other) const bool HashKey::operator==(const HashKey& other) const {
{
// Quick exit for the same object. // Quick exit for the same object.
if ( this == &other ) if ( this == &other )
return true; return true;
return Equal(other.key, other.size, other.hash); return Equal(other.key, other.size, other.hash);
} }
bool HashKey::operator!=(const HashKey& other) const bool HashKey::operator!=(const HashKey& other) const {
{
// Quick exit for different objects. // Quick exit for different objects.
if ( this != &other ) if ( this != &other )
return true; return true;
return ! Equal(other.key, other.size, other.hash); return ! Equal(other.key, other.size, other.hash);
} }
bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash) const bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash) const {
{
// If the key memory is the same just return true. // If the key memory is the same just return true.
if ( key == other_key && size == other_size ) if ( key == other_key && size == other_size )
return true; return true;
@ -599,10 +498,9 @@ bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash)
return false; return false;
return (hash == other_hash) && (size == other_size) && (memcmp(key, other_key, size) == 0); return (hash == other_hash) && (size == other_size) && (memcmp(key, other_key, size) == 0);
} }
HashKey& HashKey::operator=(const HashKey& other) HashKey& HashKey::operator=(const HashKey& other) {
{
if ( this == &other ) if ( this == &other )
return *this; return *this;
@ -618,10 +516,9 @@ HashKey& HashKey::operator=(const HashKey& other)
key = CopyKey(other.key, other.size); key = CopyKey(other.key, other.size);
return *this; return *this;
} }
HashKey& HashKey::operator=(HashKey&& other) noexcept HashKey& HashKey::operator=(HashKey&& other) noexcept {
{
if ( this == &other ) if ( this == &other )
return *this; return *this;
@ -641,32 +538,29 @@ HashKey& HashKey::operator=(HashKey&& other) noexcept
other.key = nullptr; other.key = nullptr;
return *this; return *this;
} }
TEST_SUITE_BEGIN("Hash"); TEST_SUITE_BEGIN("Hash");
TEST_CASE("equality") TEST_CASE("equality") {
{
HashKey h1(12345); HashKey h1(12345);
HashKey h2(12345); HashKey h2(12345);
HashKey h3(67890); HashKey h3(67890);
CHECK(h1 == h2); CHECK(h1 == h2);
CHECK(h1 != h3); CHECK(h1 != h3);
} }
TEST_CASE("copy assignment") TEST_CASE("copy assignment") {
{
HashKey h1(12345); HashKey h1(12345);
HashKey h2 = h1; HashKey h2 = h1;
HashKey h3{h1}; HashKey h3{h1};
CHECK(h1 == h2); CHECK(h1 == h2);
CHECK(h1 == h3); CHECK(h1 == h3);
} }
TEST_CASE("move assignment") TEST_CASE("move assignment") {
{
HashKey h1(12345); HashKey h1(12345);
HashKey h2(12345); HashKey h2(12345);
HashKey h3(12345); HashKey h3(12345);
@ -676,8 +570,8 @@ TEST_CASE("move assignment")
CHECK(h1 == h4); CHECK(h1 == h4);
CHECK(h1 == h5); CHECK(h1 == h5);
} }
TEST_SUITE_END(); TEST_SUITE_END();
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -27,37 +27,32 @@
// to allow md5_hmac_bif access to the hmac seed // to allow md5_hmac_bif access to the hmac seed
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
namespace zeek namespace zeek {
{
class String; class String;
class ODesc; class ODesc;
} } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class Frame; class Frame;
class BifReturnVal; class BifReturnVal;
} } // namespace zeek::detail
namespace zeek::BifFunc namespace zeek::BifFunc {
{
extern zeek::detail::BifReturnVal md5_hmac_bif(zeek::detail::Frame* frame, const zeek::Args*); extern zeek::detail::BifReturnVal md5_hmac_bif(zeek::detail::Frame* frame, const zeek::Args*);
} }
namespace zeek::detail namespace zeek::detail {
{
using hash_t = uint64_t; using hash_t = uint64_t;
using hash64_t = uint64_t; using hash64_t = uint64_t;
using hash128_t = uint64_t[2]; using hash128_t = uint64_t[2];
using hash256_t = uint64_t[4]; using hash256_t = uint64_t[4];
class KeyedHash class KeyedHash {
{
public: public:
/** /**
* Generate a 64 bit digest hash. * Generate a 64 bit digest hash.
@ -215,22 +210,15 @@ private:
inline static uint8_t shared_hmac_md5_key[16]; inline static uint8_t shared_hmac_md5_key[16];
inline static bool seeds_initialized = false; inline static bool seeds_initialized = false;
friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, unsigned char digest[16]);
unsigned char digest[16]);
friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*); friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*);
}; };
enum HashKeyTag enum HashKeyTag { HASH_KEY_INT, HASH_KEY_DOUBLE, HASH_KEY_STRING };
{
HASH_KEY_INT,
HASH_KEY_DOUBLE,
HASH_KEY_STRING
};
constexpr int NUM_HASH_KEYS = HASH_KEY_STRING + 1; constexpr int NUM_HASH_KEYS = HASH_KEY_STRING + 1;
class HashKey class HashKey {
{
public: public:
explicit HashKey() { key_u.u32 = 0; } explicit HashKey() { key_u.u32 = 0; }
explicit HashKey(bool b); explicit HashKey(bool b);
@ -280,15 +268,15 @@ public:
// A HashKey is "allocated" when the underlying key points somewhere // A HashKey is "allocated" when the underlying key points somewhere
// other than our internal key_u union. This is almost like // other than our internal key_u union. This is almost like
// is_our_dynamic, but remains true also after TakeKey(). // is_our_dynamic, but remains true also after TakeKey().
bool IsAllocated() const bool IsAllocated() const { return (key != nullptr && key != reinterpret_cast<const char*>(&key_u)); }
{
return (key != nullptr && key != reinterpret_cast<const char*>(&key_u));
}
// Buffer size reservation. Repeated calls to these methods // Buffer size reservation. Repeated calls to these methods
// incrementally build up the eventual buffer size to be allocated via // incrementally build up the eventual buffer size to be allocated via
// Allocate(). // Allocate().
template <typename T> void ReserveType(const char* tag) { Reserve(tag, sizeof(T), sizeof(T)); } template<typename T>
void ReserveType(const char* tag) {
Reserve(tag, sizeof(T), sizeof(T));
}
void Reserve(const char* tag, size_t addl_size, size_t alignment = 0); void Reserve(const char* tag, size_t addl_size, size_t alignment = 0);
// Allocates the reserved amount of memory // Allocates the reserved amount of memory
@ -387,8 +375,8 @@ protected:
bool is_our_dynamic = false; bool is_our_dynamic = false;
size_t write_size = 0; size_t write_size = 0;
mutable size_t read_size = 0; mutable size_t read_size = 0;
}; };
extern void init_hash_function(); extern void init_hash_function();
} // namespace zeek::detail } // namespace zeek::detail

339
src/ID.cc
View file

@ -22,8 +22,7 @@
#include "zeek/zeekygen/ScriptInfo.h" #include "zeek/zeekygen/ScriptInfo.h"
#include "zeek/zeekygen/utils.h" #include "zeek/zeekygen/utils.h"
namespace zeek namespace zeek {
{
RecordTypePtr id::conn_id; RecordTypePtr id::conn_id;
RecordTypePtr id::endpoint; RecordTypePtr id::endpoint;
@ -37,61 +36,51 @@ TableTypePtr id::count_set;
VectorTypePtr id::string_vec; VectorTypePtr id::string_vec;
VectorTypePtr id::index_vec; VectorTypePtr id::index_vec;
const detail::IDPtr& id::find(std::string_view name) const detail::IDPtr& id::find(std::string_view name) { return zeek::detail::global_scope()->Find(name); }
{
return zeek::detail::global_scope()->Find(name);
}
const TypePtr& id::find_type(std::string_view name) const TypePtr& id::find_type(std::string_view name) {
{
auto id = zeek::detail::global_scope()->Find(name); auto id = zeek::detail::global_scope()->Find(name);
if ( ! id ) if ( ! id )
reporter->InternalError("Failed to find type named: %s", std::string(name).data()); reporter->InternalError("Failed to find type named: %s", std::string(name).data());
return id->GetType(); return id->GetType();
} }
const ValPtr& id::find_val(std::string_view name) const ValPtr& id::find_val(std::string_view name) {
{
auto id = zeek::detail::global_scope()->Find(name); auto id = zeek::detail::global_scope()->Find(name);
if ( ! id ) if ( ! id )
reporter->InternalError("Failed to find variable named: %s", std::string(name).data()); reporter->InternalError("Failed to find variable named: %s", std::string(name).data());
return id->GetVal(); return id->GetVal();
} }
const ValPtr& id::find_const(std::string_view name) const ValPtr& id::find_const(std::string_view name) {
{
auto id = zeek::detail::global_scope()->Find(name); auto id = zeek::detail::global_scope()->Find(name);
if ( ! id ) if ( ! id )
reporter->InternalError("Failed to find variable named: %s", std::string(name).data()); reporter->InternalError("Failed to find variable named: %s", std::string(name).data());
if ( ! id->IsConst() ) if ( ! id->IsConst() )
reporter->InternalError("Variable is not 'const', but expected to be: %s", reporter->InternalError("Variable is not 'const', but expected to be: %s", std::string(name).data());
std::string(name).data());
return id->GetVal(); return id->GetVal();
} }
FuncPtr id::find_func(std::string_view name) FuncPtr id::find_func(std::string_view name) {
{
const auto& v = id::find_val(name); const auto& v = id::find_val(name);
if ( ! v ) if ( ! v )
return nullptr; return nullptr;
if ( ! IsFunc(v->GetType()->Tag()) ) if ( ! IsFunc(v->GetType()->Tag()) )
reporter->InternalError("Expected variable '%s' to be a function", reporter->InternalError("Expected variable '%s' to be a function", std::string(name).data());
std::string(name).data());
return v.get()->As<FuncVal*>()->AsFuncPtr(); return v.get()->As<FuncVal*>()->AsFuncPtr();
} }
void id::detail::init_types() void id::detail::init_types() {
{
conn_id = id::find_type<RecordType>("conn_id"); conn_id = id::find_type<RecordType>("conn_id");
endpoint = id::find_type<RecordType>("endpoint"); endpoint = id::find_type<RecordType>("endpoint");
connection = id::find_type<RecordType>("connection"); connection = id::find_type<RecordType>("connection");
@ -103,13 +92,11 @@ void id::detail::init_types()
count_set = id::find_type<TableType>("count_set"); count_set = id::find_type<TableType>("count_set");
string_vec = id::find_type<VectorType>("string_vec"); string_vec = id::find_type<VectorType>("string_vec");
index_vec = id::find_type<VectorType>("index_vec"); index_vec = id::find_type<VectorType>("index_vec");
} }
namespace detail namespace detail {
{
ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export) ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export) {
{
name = util::copy_string(arg_name); name = util::copy_string(arg_name);
scope = arg_scope; scope = arg_scope;
is_export = arg_is_export; is_export = arg_is_export;
@ -128,31 +115,20 @@ ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export)
infer_return_type = false; infer_return_type = false;
SetLocationInfo(&start_location, &end_location); SetLocationInfo(&start_location, &end_location);
} }
ID::~ID() ID::~ID() {
{
ClearOptInfo(); ClearOptInfo();
delete[] name; delete[] name;
} }
std::string ID::ModuleName() const std::string ID::ModuleName() const { return extract_module_name(name); }
{
return extract_module_name(name);
}
void ID::SetType(TypePtr t) void ID::SetType(TypePtr t) { type = std::move(t); }
{
type = std::move(t);
}
void ID::ClearVal() void ID::ClearVal() { val = nullptr; }
{
val = nullptr;
}
void ID::SetVal(ValPtr v) void ID::SetVal(ValPtr v) {
{
val = std::move(v); val = std::move(v);
Modified(); Modified();
@ -160,13 +136,10 @@ void ID::SetVal(ValPtr v)
UpdateValID(); UpdateValID();
#endif #endif
if ( type && val && type->Tag() == TYPE_FUNC && if ( type && val && type->Tag() == TYPE_FUNC && type->AsFuncType()->Flavor() == FUNC_FLAVOR_EVENT ) {
type->AsFuncType()->Flavor() == FUNC_FLAVOR_EVENT )
{
EventHandler* handler = event_registry->Lookup(name); EventHandler* handler = event_registry->Lookup(name);
auto func = val.get()->As<FuncVal*>()->AsFuncPtr(); auto func = val.get()->As<FuncVal*>()->AsFuncPtr();
if ( ! handler ) if ( ! handler ) {
{
handler = new EventHandler(name); handler = new EventHandler(name);
handler->SetFunc(func); handler->SetFunc(func);
event_registry->Register(handler, true); event_registry->Register(handler, true);
@ -174,106 +147,85 @@ void ID::SetVal(ValPtr v)
if ( ! IsExport() ) if ( ! IsExport() )
register_new_event({NewRef{}, this}); register_new_event({NewRef{}, this});
} }
else else {
{
// Otherwise, internally defined events cannot // Otherwise, internally defined events cannot
// have local handler. // have local handler.
handler->SetFunc(func); handler->SetFunc(func);
} }
} }
} }
void ID::SetVal(ValPtr v, InitClass c) void ID::SetVal(ValPtr v, InitClass c) {
{ if ( c == INIT_NONE || c == INIT_FULL ) {
if ( c == INIT_NONE || c == INIT_FULL )
{
SetVal(std::move(v)); SetVal(std::move(v));
return; return;
} }
if ( type->Tag() != TYPE_TABLE && (type->Tag() != TYPE_PATTERN || c == INIT_REMOVE) && if ( type->Tag() != TYPE_TABLE && (type->Tag() != TYPE_PATTERN || c == INIT_REMOVE) &&
(type->Tag() != TYPE_VECTOR || c == INIT_REMOVE) ) (type->Tag() != TYPE_VECTOR || c == INIT_REMOVE) ) {
{
if ( c == INIT_EXTRA ) if ( c == INIT_EXTRA )
Error("+= initializer only applies to tables, sets, vectors and patterns", v.get()); Error("+= initializer only applies to tables, sets, vectors and patterns", v.get());
else else
Error("-= initializer only applies to tables and sets", v.get()); Error("-= initializer only applies to tables and sets", v.get());
} }
else else {
{ if ( c == INIT_EXTRA ) {
if ( c == INIT_EXTRA ) if ( ! val ) {
{
if ( ! val )
{
SetVal(std::move(v)); SetVal(std::move(v));
return; return;
} }
else else
v->AddTo(val.get(), false); v->AddTo(val.get(), false);
} }
else else {
{
if ( val ) if ( val )
v->RemoveFrom(val.get()); v->RemoveFrom(val.get());
} }
} }
} }
void ID::SetVal(ExprPtr ev, InitClass c) void ID::SetVal(ExprPtr ev, InitClass c) {
{
const auto& a = attrs->Find(c == INIT_EXTRA ? ATTR_ADD_FUNC : ATTR_DEL_FUNC); const auto& a = attrs->Find(c == INIT_EXTRA ? ATTR_ADD_FUNC : ATTR_DEL_FUNC);
if ( ! a ) if ( ! a )
Internal("no add/delete function in ID::SetVal"); Internal("no add/delete function in ID::SetVal");
if ( ! val ) if ( ! val ) {
{ Error(zeek::util::fmt("%s initializer applied to ID without value", c == INIT_EXTRA ? "+=" : "-="), this);
Error(zeek::util::fmt("%s initializer applied to ID without value",
c == INIT_EXTRA ? "+=" : "-="),
this);
return; return;
} }
EvalFunc(a->GetExpr(), std::move(ev)); EvalFunc(a->GetExpr(), std::move(ev));
} }
bool ID::IsRedefinable() const bool ID::IsRedefinable() const { return GetAttr(ATTR_REDEF) != nullptr; }
{
return GetAttr(ATTR_REDEF) != nullptr;
}
void ID::SetAttrs(AttributesPtr a) void ID::SetAttrs(AttributesPtr a) {
{
attrs = nullptr; attrs = nullptr;
AddAttrs(std::move(a)); AddAttrs(std::move(a));
} }
void ID::UpdateValAttrs() void ID::UpdateValAttrs() {
{
if ( ! attrs ) if ( ! attrs )
return; return;
auto tag = GetType()->Tag(); auto tag = GetType()->Tag();
if ( tag == TYPE_FUNC ) if ( tag == TYPE_FUNC ) {
{
const auto& attr = attrs->Find(ATTR_ERROR_HANDLER); const auto& attr = attrs->Find(ATTR_ERROR_HANDLER);
if ( attr ) if ( attr )
event_registry->SetErrorHandler(Name()); event_registry->SetErrorHandler(Name());
} }
if ( tag == TYPE_RECORD ) if ( tag == TYPE_RECORD ) {
{
const auto& attr = attrs->Find(ATTR_LOG); const auto& attr = attrs->Find(ATTR_LOG);
if ( attr ) if ( attr ) {
{
// Apply &log to all record fields. // Apply &log to all record fields.
RecordType* rt = GetType()->AsRecordType(); RecordType* rt = GetType()->AsRecordType();
for ( int i = 0; i < rt->NumFields(); ++i ) for ( int i = 0; i < rt->NumFields(); ++i ) {
{
TypeDecl* fd = rt->FieldDecl(i); TypeDecl* fd = rt->FieldDecl(i);
if ( ! fd->attrs ) if ( ! fd->attrs )
@ -294,28 +246,20 @@ void ID::UpdateValAttrs()
else if ( vtag == TYPE_FILE ) else if ( vtag == TYPE_FILE )
val->AsFile()->SetAttrs(attrs.get()); val->AsFile()->SetAttrs(attrs.get());
} }
const AttrPtr& ID::GetAttr(AttrTag t) const const AttrPtr& ID::GetAttr(AttrTag t) const { return attrs ? attrs->Find(t) : Attr::nil; }
{
return attrs ? attrs->Find(t) : Attr::nil;
}
bool ID::IsDeprecated() const bool ID::IsDeprecated() const { return GetAttr(ATTR_DEPRECATED) != nullptr; }
{
return GetAttr(ATTR_DEPRECATED) != nullptr;
}
void ID::MakeDeprecated(ExprPtr deprecation) void ID::MakeDeprecated(ExprPtr deprecation) {
{
if ( IsDeprecated() ) if ( IsDeprecated() )
return; return;
AddAttr(make_intrusive<Attr>(ATTR_DEPRECATED, std::move(deprecation))); AddAttr(make_intrusive<Attr>(ATTR_DEPRECATED, std::move(deprecation)));
} }
std::string ID::GetDeprecationWarning() const std::string ID::GetDeprecationWarning() const {
{
std::string result; std::string result;
const auto& depr_attr = GetAttr(ATTR_DEPRECATED); const auto& depr_attr = GetAttr(ATTR_DEPRECATED);
@ -326,33 +270,29 @@ std::string ID::GetDeprecationWarning() const
return util::fmt("deprecated (%s)", Name()); return util::fmt("deprecated (%s)", Name());
else else
return util::fmt("deprecated (%s): %s", Name(), result.c_str()); return util::fmt("deprecated (%s): %s", Name(), result.c_str());
} }
void ID::AddAttr(AttrPtr a, bool is_redef) void ID::AddAttr(AttrPtr a, bool is_redef) {
{
std::vector<AttrPtr> attrv{std::move(a)}; std::vector<AttrPtr> attrv{std::move(a)};
auto attrs = make_intrusive<Attributes>(std::move(attrv), GetType(), false, IsGlobal()); auto attrs = make_intrusive<Attributes>(std::move(attrv), GetType(), false, IsGlobal());
AddAttrs(std::move(attrs), is_redef); AddAttrs(std::move(attrs), is_redef);
} }
void ID::AddAttrs(AttributesPtr a, bool is_redef) void ID::AddAttrs(AttributesPtr a, bool is_redef) {
{
if ( attrs ) if ( attrs )
attrs->AddAttrs(a, is_redef); attrs->AddAttrs(a, is_redef);
else else
attrs = std::move(a); attrs = std::move(a);
UpdateValAttrs(); UpdateValAttrs();
} }
void ID::RemoveAttr(AttrTag a) void ID::RemoveAttr(AttrTag a) {
{
if ( attrs ) if ( attrs )
attrs->RemoveAttr(a); attrs->RemoveAttr(a);
} }
void ID::SetOption() void ID::SetOption() {
{
if ( is_option ) if ( is_option )
return; return;
@ -361,25 +301,22 @@ void ID::SetOption()
// option implied redefinable // option implied redefinable
if ( ! IsRedefinable() ) if ( ! IsRedefinable() )
AddAttr(make_intrusive<Attr>(ATTR_REDEF)); AddAttr(make_intrusive<Attr>(ATTR_REDEF));
} }
void ID::EvalFunc(ExprPtr ef, ExprPtr ev) void ID::EvalFunc(ExprPtr ef, ExprPtr ev) {
{
auto arg1 = make_intrusive<detail::ConstExpr>(val); auto arg1 = make_intrusive<detail::ConstExpr>(val);
auto args = make_intrusive<detail::ListExpr>(); auto args = make_intrusive<detail::ListExpr>();
args->Append(std::move(arg1)); args->Append(std::move(arg1));
args->Append(std::move(ev)); args->Append(std::move(ev));
auto ce = make_intrusive<CallExpr>(std::move(ef), std::move(args)); auto ce = make_intrusive<CallExpr>(std::move(ef), std::move(args));
SetVal(ce->Eval(nullptr)); SetVal(ce->Eval(nullptr));
} }
TraversalCode ID::Traverse(TraversalCallback* cb) const TraversalCode ID::Traverse(TraversalCallback* cb) const {
{
TraversalCode tc = cb->PreID(this); TraversalCode tc = cb->PreID(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
if ( is_type ) if ( is_type ) {
{
tc = cb->PreTypedef(this); tc = cb->PreTypedef(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
@ -388,14 +325,12 @@ TraversalCode ID::Traverse(TraversalCallback* cb) const
} }
// FIXME: Perhaps we should be checking at other than global scope. // FIXME: Perhaps we should be checking at other than global scope.
else if ( val && IsFunc(val->GetType()->Tag()) && cb->current_scope == detail::global_scope() ) else if ( val && IsFunc(val->GetType()->Tag()) && cb->current_scope == detail::global_scope() ) {
{
tc = val->AsFunc()->Traverse(cb); tc = val->AsFunc()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
} }
else if ( ! is_enum_const ) else if ( ! is_enum_const ) {
{
tc = cb->PreDecl(this); tc = cb->PreDecl(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
@ -405,44 +340,35 @@ TraversalCode ID::Traverse(TraversalCallback* cb) const
tc = cb->PostID(this); tc = cb->PostID(this);
HANDLE_TC_EXPR_POST(tc); HANDLE_TC_EXPR_POST(tc);
} }
void ID::Error(const char* msg, const Obj* o2) void ID::Error(const char* msg, const Obj* o2) {
{
Obj::Error(msg, o2, true); Obj::Error(msg, o2, true);
SetType(error_type()); SetType(error_type());
} }
void ID::Describe(ODesc* d) const void ID::Describe(ODesc* d) const { d->Add(name); }
{
d->Add(name);
}
void ID::DescribeExtended(ODesc* d) const void ID::DescribeExtended(ODesc* d) const {
{
d->Add(name); d->Add(name);
if ( type ) if ( type ) {
{
d->Add(" : "); d->Add(" : ");
type->Describe(d); type->Describe(d);
} }
if ( val ) if ( val ) {
{
d->Add(" = "); d->Add(" = ");
val->Describe(d); val->Describe(d);
} }
if ( attrs ) if ( attrs ) {
{
d->Add(" "); d->Add(" ");
attrs->Describe(d); attrs->Describe(d);
} }
} }
void ID::DescribeReSTShort(ODesc* d) const void ID::DescribeReSTShort(ODesc* d) const {
{
if ( is_type ) if ( is_type )
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
else else
@ -451,26 +377,19 @@ void ID::DescribeReSTShort(ODesc* d) const
d->Add(name); d->Add(name);
d->Add("`"); d->Add("`");
if ( type ) if ( type ) {
{
d->Add(": "); d->Add(": ");
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
if ( ! is_type && ! type->GetName().empty() ) if ( ! is_type && ! type->GetName().empty() )
d->Add(type->GetName().c_str()); d->Add(type->GetName().c_str());
else else {
{
TypeTag t = type->Tag(); TypeTag t = type->Tag();
switch ( t ) switch ( t ) {
{ case TYPE_TABLE: d->Add(type->IsSet() ? "set" : type_name(t)); break;
case TYPE_TABLE:
d->Add(type->IsSet() ? "set" : type_name(t));
break;
case TYPE_FUNC: case TYPE_FUNC: d->Add(type->AsFuncType()->FlavorString().c_str()); break;
d->Add(type->AsFuncType()->FlavorString().c_str());
break;
case TYPE_ENUM: case TYPE_ENUM:
if ( is_type ) if ( is_type )
@ -479,26 +398,21 @@ void ID::DescribeReSTShort(ODesc* d) const
d->Add(zeekygen_mgr->GetEnumTypeName(Name()).c_str()); d->Add(zeekygen_mgr->GetEnumTypeName(Name()).c_str());
break; break;
default: default: d->Add(type_name(t)); break;
d->Add(type_name(t));
break;
} }
} }
d->Add("`"); d->Add("`");
} }
if ( attrs ) if ( attrs ) {
{
d->SP(); d->SP();
attrs->DescribeReST(d, true); attrs->DescribeReST(d, true);
} }
} }
void ID::DescribeReST(ODesc* d, bool roles_only) const void ID::DescribeReST(ODesc* d, bool roles_only) const {
{ if ( roles_only ) {
if ( roles_only )
{
if ( is_type ) if ( is_type )
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
else else
@ -506,8 +420,7 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->Add(name); d->Add(name);
d->Add("`"); d->Add("`");
} }
else else {
{
if ( is_type ) if ( is_type )
d->Add(".. zeek:type:: "); d->Add(".. zeek:type:: ");
else else
@ -515,8 +428,7 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->Add(name); d->Add(name);
if ( auto sc = zeek::zeekygen::detail::source_code_range(this) ) if ( auto sc = zeek::zeekygen::detail::source_code_range(this) ) {
{
d->PushIndent(); d->PushIndent();
d->Add(util::fmt(":source-code: %s", sc->data())); d->Add(util::fmt(":source-code: %s", sc->data()));
d->PopIndentNoNL(); d->PopIndentNoNL();
@ -526,36 +438,28 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->PushIndent(); d->PushIndent();
d->NL(); d->NL();
if ( type ) if ( type ) {
{
d->Add(":Type: "); d->Add(":Type: ");
if ( ! is_type && ! type->GetName().empty() ) if ( ! is_type && ! type->GetName().empty() ) {
{
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
d->Add(type->GetName()); d->Add(type->GetName());
d->Add("`"); d->Add("`");
} }
else else {
{
type->DescribeReST(d, roles_only); type->DescribeReST(d, roles_only);
if ( IsFunc(type->Tag()) ) if ( IsFunc(type->Tag()) ) {
{
auto ft = type->AsFuncType(); auto ft = type->AsFuncType();
if ( ft->Flavor() == FUNC_FLAVOR_EVENT || ft->Flavor() == FUNC_FLAVOR_HOOK ) if ( ft->Flavor() == FUNC_FLAVOR_EVENT || ft->Flavor() == FUNC_FLAVOR_HOOK ) {
{
const auto& protos = ft->Prototypes(); const auto& protos = ft->Prototypes();
if ( protos.size() > 1 ) if ( protos.size() > 1 ) {
{
auto first = true; auto first = true;
for ( const auto& proto : protos ) for ( const auto& proto : protos ) {
{ if ( first ) {
if ( first )
{
first = false; first = false;
continue; continue;
} }
@ -575,8 +479,7 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->NL(); d->NL();
} }
if ( attrs ) if ( attrs ) {
{
d->Add(":Attributes: "); d->Add(":Attributes: ");
attrs->DescribeReST(d); attrs->DescribeReST(d);
d->NL(); d->NL();
@ -586,20 +489,16 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
// Values within Version module are likely to include a // Values within Version module are likely to include a
// constantly-changing version number and be a frequent // constantly-changing version number and be a frequent
// source of error/desynchronization, so don't include them. // source of error/desynchronization, so don't include them.
ModuleName() != "Version" ) ModuleName() != "Version" ) {
{
d->Add(":Default:"); d->Add(":Default:");
auto ii = zeekygen_mgr->GetIdentifierInfo(Name()); auto ii = zeekygen_mgr->GetIdentifierInfo(Name());
auto redefs = ii->GetRedefs(); auto redefs = ii->GetRedefs();
const auto& iv = ! redefs.empty() && ii->InitialVal() ? ii->InitialVal() : val; const auto& iv = ! redefs.empty() && ii->InitialVal() ? ii->InitialVal() : val;
if ( type->InternalType() == TYPE_INTERNAL_OTHER ) if ( type->InternalType() == TYPE_INTERNAL_OTHER ) {
{ switch ( type->Tag() ) {
switch ( type->Tag() )
{
case TYPE_TABLE: case TYPE_TABLE:
if ( iv->AsTable()->Length() == 0 ) if ( iv->AsTable()->Length() == 0 ) {
{
d->Add(" ``{}``"); d->Add(" ``{}``");
d->NL(); d->NL();
break; break;
@ -618,15 +517,13 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
} }
} }
else else {
{
d->SP(); d->SP();
iv->DescribeReST(d); iv->DescribeReST(d);
d->NL(); d->NL();
} }
for ( auto& ir : redefs ) for ( auto& ir : redefs ) {
{
if ( ! ir->init_expr ) if ( ! ir->init_expr )
continue; continue;
@ -661,23 +558,18 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->PopIndent(); d->PopIndent();
} }
} }
} }
#ifdef DEBUG #ifdef DEBUG
void ID::UpdateValID() void ID::UpdateValID() {
{
if ( IsGlobal() && val && name && name[0] != '#' ) if ( IsGlobal() && val && name && name[0] != '#' )
val->SetID(this); val->SetID(this);
} }
#endif #endif
void ID::AddOptionHandler(FuncPtr callback, int priority) void ID::AddOptionHandler(FuncPtr callback, int priority) { option_handlers.emplace(priority, std::move(callback)); }
{
option_handlers.emplace(priority, std::move(callback));
}
std::vector<Func*> ID::GetOptionHandlers() const std::vector<Func*> ID::GetOptionHandlers() const {
{
// multimap is sorted // multimap is sorted
// It might be worth caching this if we expect it to be called // It might be worth caching this if we expect it to be called
// a lot... // a lot...
@ -685,14 +577,13 @@ std::vector<Func*> ID::GetOptionHandlers() const
for ( auto& element : option_handlers ) for ( auto& element : option_handlers )
v.push_back(element.second.get()); v.push_back(element.second.get());
return v; return v;
} }
void ID::ClearOptInfo() void ID::ClearOptInfo() {
{
delete opt_info; delete opt_info;
opt_info = nullptr; opt_info = nullptr;
} }
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -13,8 +13,7 @@
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/TraverseTypes.h" #include "zeek/TraverseTypes.h"
namespace zeek namespace zeek {
{
class Func; class Func;
class Val; class Val;
@ -31,29 +30,22 @@ using EnumTypePtr = IntrusivePtr<EnumType>;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
} } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class Attributes; class Attributes;
class Expr; class Expr;
using ExprPtr = IntrusivePtr<Expr>; using ExprPtr = IntrusivePtr<Expr>;
enum InitClass enum InitClass {
{
INIT_NONE, INIT_NONE,
INIT_FULL, INIT_FULL,
INIT_EXTRA, INIT_EXTRA,
INIT_REMOVE, INIT_REMOVE,
INIT_SKIP, INIT_SKIP,
}; };
enum IDScope enum IDScope { SCOPE_FUNCTION, SCOPE_MODULE, SCOPE_GLOBAL };
{
SCOPE_FUNCTION,
SCOPE_MODULE,
SCOPE_GLOBAL
};
class ID; class ID;
using IDPtr = IntrusivePtr<ID>; using IDPtr = IntrusivePtr<ID>;
@ -61,8 +53,7 @@ using IDSet = std::unordered_set<const ID*>;
class IDOptInfo; class IDOptInfo;
class ID final : public Obj, public notifier::detail::Modifiable class ID final : public Obj, public notifier::detail::Modifiable {
{
public: public:
static inline const IDPtr nil; static inline const IDPtr nil;
@ -84,7 +75,10 @@ public:
const TypePtr& GetType() const { return type; } const TypePtr& GetType() const { return type; }
template <class T> IntrusivePtr<T> GetType() const { return cast_intrusive<T>(type); } template<class T>
IntrusivePtr<T> GetType() const {
return cast_intrusive<T>(type);
}
bool IsType() const { return is_type; } bool IsType() const { return is_type; }
@ -179,12 +173,11 @@ protected:
// via the associated pointer, to allow it to be modified in // via the associated pointer, to allow it to be modified in
// contexts where the ID is itself "const". // contexts where the ID is itself "const".
IDOptInfo* opt_info; IDOptInfo* opt_info;
}; };
} // namespace zeek::detail } // namespace zeek::detail
namespace zeek::id namespace zeek::id {
{
/** /**
* Lookup an ID in the global module and return it, if one exists; * Lookup an ID in the global module and return it, if one exists;
@ -208,10 +201,10 @@ const TypePtr& find_type(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The type of the identifier. * @return The type of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_type(std::string_view name) template<class T>
{ IntrusivePtr<T> find_type(std::string_view name) {
return cast_intrusive<T>(find_type(name)); return cast_intrusive<T>(find_type(name));
} }
/** /**
* Lookup an ID by its name and return its value. A fatal occurs if the ID * Lookup an ID by its name and return its value. A fatal occurs if the ID
@ -227,10 +220,10 @@ const ValPtr& find_val(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The current value of the identifier. * @return The current value of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_val(std::string_view name) template<class T>
{ IntrusivePtr<T> find_val(std::string_view name) {
return cast_intrusive<T>(find_val(name)); return cast_intrusive<T>(find_val(name));
} }
/** /**
* Lookup an ID by its name and return its value. A fatal occurs if the ID * Lookup an ID by its name and return its value. A fatal occurs if the ID
@ -246,10 +239,10 @@ const ValPtr& find_const(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The current value of the identifier. * @return The current value of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_const(std::string_view name) template<class T>
{ IntrusivePtr<T> find_const(std::string_view name) {
return cast_intrusive<T>(find_const(name)); return cast_intrusive<T>(find_const(name));
} }
/** /**
* Lookup an ID by its name and return the function it references. * Lookup an ID by its name and return the function it references.
@ -271,10 +264,9 @@ extern TableTypePtr count_set;
extern VectorTypePtr string_vec; extern VectorTypePtr string_vec;
extern VectorTypePtr index_vec; extern VectorTypePtr index_vec;
namespace detail namespace detail {
{
void init_types(); void init_types();
} // namespace detail } // namespace detail
} // namespace zeek::id } // namespace zeek::id

389
src/IP.cc
View file

@ -13,41 +13,34 @@
#include "zeek/Var.h" #include "zeek/Var.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek namespace zeek {
{
bool IPv6_Hdr::IsOptionTruncated(uint16_t off) const bool IPv6_Hdr::IsOptionTruncated(uint16_t off) const {
{ if ( Length() < off ) {
if ( Length() < off )
{
reporter->Weird("truncated_IPv6_option"); reporter->Weird("truncated_IPv6_option");
return true; return true;
} }
return false; return false;
} }
static VectorValPtr BuildOptionsVal(const u_char* data, int len) static VectorValPtr BuildOptionsVal(const u_char* data, int len) {
{
auto vv = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_options")); auto vv = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_options"));
while ( len > 0 && static_cast<size_t>(len) >= sizeof(struct ip6_opt) ) while ( len > 0 && static_cast<size_t>(len) >= sizeof(struct ip6_opt) ) {
{
static auto ip6_option_type = id::find_type<RecordType>("ip6_option"); static auto ip6_option_type = id::find_type<RecordType>("ip6_option");
const struct ip6_opt* opt = (const struct ip6_opt*)data; const struct ip6_opt* opt = (const struct ip6_opt*)data;
auto rv = make_intrusive<RecordVal>(ip6_option_type); auto rv = make_intrusive<RecordVal>(ip6_option_type);
rv->Assign(0, opt->ip6o_type); rv->Assign(0, opt->ip6o_type);
if ( opt->ip6o_type == 0 ) if ( opt->ip6o_type == 0 ) {
{
// Pad1 option // Pad1 option
rv->Assign(1, 0); rv->Assign(1, 0);
rv->Assign(2, val_mgr->EmptyString()); rv->Assign(2, val_mgr->EmptyString());
data += sizeof(uint8_t); data += sizeof(uint8_t);
len -= sizeof(uint8_t); len -= sizeof(uint8_t);
} }
else else {
{
// PadN or other option // PadN or other option
uint16_t off = 2 * sizeof(uint8_t); uint16_t off = 2 * sizeof(uint8_t);
@ -64,16 +57,13 @@ static VectorValPtr BuildOptionsVal(const u_char* data, int len)
} }
return vv; return vv;
} }
RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const {
{
RecordValPtr rv; RecordValPtr rv;
switch ( type ) switch ( type ) {
{ case IPPROTO_IPV6: {
case IPPROTO_IPV6:
{
static auto ip6_hdr_type = id::find_type<RecordType>("ip6_hdr"); static auto ip6_hdr_type = id::find_type<RecordType>("ip6_hdr");
rv = make_intrusive<RecordVal>(ip6_hdr_type); rv = make_intrusive<RecordVal>(ip6_hdr_type);
const struct ip6_hdr* ip6 = (const struct ip6_hdr*)data; const struct ip6_hdr* ip6 = (const struct ip6_hdr*)data;
@ -87,11 +77,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
if ( ! chain ) if ( ! chain )
chain = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_ext_hdr_chain")); chain = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_ext_hdr_chain"));
rv->Assign(7, std::move(chain)); rv->Assign(7, std::move(chain));
} } break;
break;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS: {
{
uint16_t off = 2 * sizeof(uint8_t); uint16_t off = 2 * sizeof(uint8_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
return nullptr; return nullptr;
@ -102,11 +90,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(0, hbh->ip6h_nxt); rv->Assign(0, hbh->ip6h_nxt);
rv->Assign(1, hbh->ip6h_len); rv->Assign(1, hbh->ip6h_len);
rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); rv->Assign(2, BuildOptionsVal(data + off, Length() - off));
} } break;
break;
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS: {
{
uint16_t off = 2 * sizeof(uint8_t); uint16_t off = 2 * sizeof(uint8_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
return nullptr; return nullptr;
@ -117,11 +103,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(0, dst->ip6d_nxt); rv->Assign(0, dst->ip6d_nxt);
rv->Assign(1, dst->ip6d_len); rv->Assign(1, dst->ip6d_len);
rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); rv->Assign(2, BuildOptionsVal(data + off, Length() - off));
} } break;
break;
case IPPROTO_ROUTING: case IPPROTO_ROUTING: {
{
uint16_t off = 4 * sizeof(uint8_t); uint16_t off = 4 * sizeof(uint8_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
return nullptr; return nullptr;
@ -134,11 +118,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(2, rt->ip6r_type); rv->Assign(2, rt->ip6r_type);
rv->Assign(3, rt->ip6r_segleft); rv->Assign(3, rt->ip6r_segleft);
rv->Assign(4, new String(data + off, Length() - off, true)); rv->Assign(4, new String(data + off, Length() - off, true));
} } break;
break;
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT: {
{
static auto ip6_fragment_type = id::find_type<RecordType>("ip6_fragment"); static auto ip6_fragment_type = id::find_type<RecordType>("ip6_fragment");
rv = make_intrusive<RecordVal>(ip6_fragment_type); rv = make_intrusive<RecordVal>(ip6_fragment_type);
const struct ip6_frag* frag = (const struct ip6_frag*)data; const struct ip6_frag* frag = (const struct ip6_frag*)data;
@ -148,11 +130,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(3, (ntohs(frag->ip6f_offlg) & 0x0006) >> 1); rv->Assign(3, (ntohs(frag->ip6f_offlg) & 0x0006) >> 1);
rv->Assign(4, static_cast<bool>(ntohs(frag->ip6f_offlg) & 0x0001)); rv->Assign(4, static_cast<bool>(ntohs(frag->ip6f_offlg) & 0x0001));
rv->Assign(5, static_cast<uint32_t>(ntohl(frag->ip6f_ident))); rv->Assign(5, static_cast<uint32_t>(ntohl(frag->ip6f_ident)));
} } break;
break;
case IPPROTO_AH: case IPPROTO_AH: {
{
static auto ip6_ah_type = id::find_type<RecordType>("ip6_ah"); static auto ip6_ah_type = id::find_type<RecordType>("ip6_ah");
rv = make_intrusive<RecordVal>(ip6_ah_type); rv = make_intrusive<RecordVal>(ip6_ah_type);
rv->Assign(0, ((ip6_ext*)data)->ip6e_nxt); rv->Assign(0, ((ip6_ext*)data)->ip6e_nxt);
@ -160,29 +140,24 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(2, ntohs(((uint16_t*)data)[1])); rv->Assign(2, ntohs(((uint16_t*)data)[1]));
rv->Assign(3, static_cast<uint32_t>(ntohl(((uint32_t*)data)[1]))); rv->Assign(3, static_cast<uint32_t>(ntohl(((uint32_t*)data)[1])));
if ( Length() >= 12 ) if ( Length() >= 12 ) {
{
// Sequence Number and ICV fields can only be extracted if // Sequence Number and ICV fields can only be extracted if
// Payload Len was non-zero for this header. // Payload Len was non-zero for this header.
rv->Assign(4, static_cast<uint32_t>(ntohl(((uint32_t*)data)[2]))); rv->Assign(4, static_cast<uint32_t>(ntohl(((uint32_t*)data)[2])));
uint16_t off = 3 * sizeof(uint32_t); uint16_t off = 3 * sizeof(uint32_t);
rv->Assign(5, new String(data + off, Length() - off, true)); rv->Assign(5, new String(data + off, Length() - off, true));
} }
} } break;
break;
case IPPROTO_ESP: case IPPROTO_ESP: {
{
static auto ip6_esp_type = id::find_type<RecordType>("ip6_esp"); static auto ip6_esp_type = id::find_type<RecordType>("ip6_esp");
rv = make_intrusive<RecordVal>(ip6_esp_type); rv = make_intrusive<RecordVal>(ip6_esp_type);
const uint32_t* esp = (const uint32_t*)data; const uint32_t* esp = (const uint32_t*)data;
rv->Assign(0, static_cast<uint32_t>(ntohl(esp[0]))); rv->Assign(0, static_cast<uint32_t>(ntohl(esp[0])));
rv->Assign(1, static_cast<uint32_t>(ntohl(esp[1]))); rv->Assign(1, static_cast<uint32_t>(ntohl(esp[1])));
} } break;
break;
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: {
{
static auto ip6_mob_type = id::find_type<RecordType>("ip6_mobility_hdr"); static auto ip6_mob_type = id::find_type<RecordType>("ip6_mobility_hdr");
rv = make_intrusive<RecordVal>(ip6_mob_type); rv = make_intrusive<RecordVal>(ip6_mob_type);
const struct ip6_mobility* mob = (const struct ip6_mobility*)data; const struct ip6_mobility* mob = (const struct ip6_mobility*)data;
@ -208,10 +183,8 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
static auto ip6_mob_back_type = id::find_type<RecordType>("ip6_mobility_back"); static auto ip6_mob_back_type = id::find_type<RecordType>("ip6_mobility_back");
static auto ip6_mob_be_type = id::find_type<RecordType>("ip6_mobility_be"); static auto ip6_mob_be_type = id::find_type<RecordType>("ip6_mobility_be");
switch ( mob->ip6mob_type ) switch ( mob->ip6mob_type ) {
{ case 0: {
case 0:
{
off += sizeof(uint16_t); off += sizeof(uint16_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -223,8 +196,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 1: case 1: {
{
off += sizeof(uint16_t) + sizeof(uint64_t); off += sizeof(uint16_t) + sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -237,8 +209,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 2: case 2: {
{
off += sizeof(uint16_t) + sizeof(uint64_t); off += sizeof(uint16_t) + sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -251,8 +222,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 3: case 3: {
{
off += sizeof(uint16_t) + 2 * sizeof(uint64_t); off += sizeof(uint16_t) + 2 * sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -260,15 +230,13 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
auto m = make_intrusive<RecordVal>(ip6_mob_hot_type); auto m = make_intrusive<RecordVal>(ip6_mob_hot_type);
m->Assign(0, ntohs(*((uint16_t*)msg_data))); m->Assign(0, ntohs(*((uint16_t*)msg_data)));
m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t)))));
m->Assign( m->Assign(2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
m->Assign(3, BuildOptionsVal(data + off, Length() - off)); m->Assign(3, BuildOptionsVal(data + off, Length() - off));
msg->Assign(4, std::move(m)); msg->Assign(4, std::move(m));
break; break;
} }
case 4: case 4: {
{
off += sizeof(uint16_t) + 2 * sizeof(uint64_t); off += sizeof(uint16_t) + 2 * sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -276,45 +244,37 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
auto m = make_intrusive<RecordVal>(ip6_mob_cot_type); auto m = make_intrusive<RecordVal>(ip6_mob_cot_type);
m->Assign(0, ntohs(*((uint16_t*)msg_data))); m->Assign(0, ntohs(*((uint16_t*)msg_data)));
m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t)))));
m->Assign( m->Assign(2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
m->Assign(3, BuildOptionsVal(data + off, Length() - off)); m->Assign(3, BuildOptionsVal(data + off, Length() - off));
msg->Assign(5, std::move(m)); msg->Assign(5, std::move(m));
break; break;
} }
case 5: case 5: {
{
off += 3 * sizeof(uint16_t); off += 3 * sizeof(uint16_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
auto m = make_intrusive<RecordVal>(ip6_mob_bu_type); auto m = make_intrusive<RecordVal>(ip6_mob_bu_type);
m->Assign(0, ntohs(*((uint16_t*)msg_data))); m->Assign(0, ntohs(*((uint16_t*)msg_data)));
m->Assign(1, static_cast<bool>( m->Assign(1, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x8000));
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x8000)); m->Assign(2, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x4000));
m->Assign(2, static_cast<bool>( m->Assign(3, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x2000));
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x4000)); m->Assign(4, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x1000));
m->Assign(3, static_cast<bool>(
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x2000));
m->Assign(4, static_cast<bool>(
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x1000));
m->Assign(5, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); m->Assign(5, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t)))));
m->Assign(6, BuildOptionsVal(data + off, Length() - off)); m->Assign(6, BuildOptionsVal(data + off, Length() - off));
msg->Assign(6, std::move(m)); msg->Assign(6, std::move(m));
break; break;
} }
case 6: case 6: {
{
off += 3 * sizeof(uint16_t); off += 3 * sizeof(uint16_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
auto m = make_intrusive<RecordVal>(ip6_mob_back_type); auto m = make_intrusive<RecordVal>(ip6_mob_back_type);
m->Assign(0, *((uint8_t*)msg_data)); m->Assign(0, *((uint8_t*)msg_data));
m->Assign(1, m->Assign(1, static_cast<bool>(*((uint8_t*)(msg_data + sizeof(uint8_t))) & 0x80));
static_cast<bool>(*((uint8_t*)(msg_data + sizeof(uint8_t))) & 0x80));
m->Assign(2, ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t))))); m->Assign(2, ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))));
m->Assign(3, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); m->Assign(3, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t)))));
m->Assign(4, BuildOptionsVal(data + off, Length() - off)); m->Assign(4, BuildOptionsVal(data + off, Length() - off));
@ -322,8 +282,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 7: case 7: {
{
off += sizeof(uint16_t) + sizeof(in6_addr); off += sizeof(uint16_t) + sizeof(in6_addr);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -337,53 +296,32 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
default: default: reporter->Weird("unknown_mobility_type", util::fmt("%d", mob->ip6mob_type)); break;
reporter->Weird("unknown_mobility_type", util::fmt("%d", mob->ip6mob_type));
break;
} }
rv->Assign(5, std::move(msg)); rv->Assign(5, std::move(msg));
} } break;
break;
default: default: break;
break;
} }
return rv; return rv;
} }
RecordValPtr IPv6_Hdr::ToVal() const RecordValPtr IPv6_Hdr::ToVal() const { return ToVal(nullptr); }
{
return ToVal(nullptr);
}
IPAddr IP_Hdr::IPHeaderSrcAddr() const IPAddr IP_Hdr::IPHeaderSrcAddr() const { return ip4 ? IPAddr(ip4->ip_src) : IPAddr(ip6->ip6_src); }
{
return ip4 ? IPAddr(ip4->ip_src) : IPAddr(ip6->ip6_src);
}
IPAddr IP_Hdr::IPHeaderDstAddr() const IPAddr IP_Hdr::IPHeaderDstAddr() const { return ip4 ? IPAddr(ip4->ip_dst) : IPAddr(ip6->ip6_dst); }
{
return ip4 ? IPAddr(ip4->ip_dst) : IPAddr(ip6->ip6_dst);
}
IPAddr IP_Hdr::SrcAddr() const IPAddr IP_Hdr::SrcAddr() const { return ip4 ? IPAddr(ip4->ip_src) : ip6_hdrs->SrcAddr(); }
{
return ip4 ? IPAddr(ip4->ip_src) : ip6_hdrs->SrcAddr();
}
IPAddr IP_Hdr::DstAddr() const IPAddr IP_Hdr::DstAddr() const { return ip4 ? IPAddr(ip4->ip_dst) : ip6_hdrs->DstAddr(); }
{
return ip4 ? IPAddr(ip4->ip_dst) : ip6_hdrs->DstAddr();
}
RecordValPtr IP_Hdr::ToIPHdrVal() const RecordValPtr IP_Hdr::ToIPHdrVal() const {
{
RecordValPtr rval; RecordValPtr rval;
if ( ip4 ) if ( ip4 ) {
{
static auto ip4_hdr_type = id::find_type<RecordType>("ip4_hdr"); static auto ip4_hdr_type = id::find_type<RecordType>("ip4_hdr");
rval = make_intrusive<RecordVal>(ip4_hdr_type); rval = make_intrusive<RecordVal>(ip4_hdr_type);
rval->Assign(0, ip4->ip_hl * 4); rval->Assign(0, ip4->ip_hl * 4);
@ -399,22 +337,19 @@ RecordValPtr IP_Hdr::ToIPHdrVal() const
rval->Assign(10, make_intrusive<AddrVal>(ip4->ip_src.s_addr)); rval->Assign(10, make_intrusive<AddrVal>(ip4->ip_src.s_addr));
rval->Assign(11, make_intrusive<AddrVal>(ip4->ip_dst.s_addr)); rval->Assign(11, make_intrusive<AddrVal>(ip4->ip_dst.s_addr));
} }
else else {
{
rval = ((*ip6_hdrs)[0])->ToVal(ip6_hdrs->ToVal()); rval = ((*ip6_hdrs)[0])->ToVal(ip6_hdrs->ToVal());
} }
return rval; return rval;
} }
RecordValPtr IP_Hdr::ToPktHdrVal() const RecordValPtr IP_Hdr::ToPktHdrVal() const {
{
static auto pkt_hdr_type = id::find_type<RecordType>("pkt_hdr"); static auto pkt_hdr_type = id::find_type<RecordType>("pkt_hdr");
return ToPktHdrVal(make_intrusive<RecordVal>(pkt_hdr_type), 0); return ToPktHdrVal(make_intrusive<RecordVal>(pkt_hdr_type), 0);
} }
RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const {
{
static auto tcp_hdr_type = id::find_type<RecordType>("tcp_hdr"); static auto tcp_hdr_type = id::find_type<RecordType>("tcp_hdr");
static auto udp_hdr_type = id::find_type<RecordType>("udp_hdr"); static auto udp_hdr_type = id::find_type<RecordType>("udp_hdr");
static auto icmp_hdr_type = id::find_type<RecordType>("icmp_hdr"); static auto icmp_hdr_type = id::find_type<RecordType>("icmp_hdr");
@ -428,10 +363,8 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
const u_char* data = Payload(); const u_char* data = Payload();
int proto = NextProto(); int proto = NextProto();
switch ( proto ) switch ( proto ) {
{ case IPPROTO_TCP: {
case IPPROTO_TCP:
{
if ( PayloadLen() < sizeof(struct tcphdr) ) if ( PayloadLen() < sizeof(struct tcphdr) )
break; break;
@ -461,8 +394,7 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
case IPPROTO_UDP: case IPPROTO_UDP: {
{
if ( PayloadLen() < sizeof(struct udphdr) ) if ( PayloadLen() < sizeof(struct udphdr) )
break; break;
@ -477,8 +409,7 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
case IPPROTO_ICMP: case IPPROTO_ICMP: {
{
if ( PayloadLen() < sizeof(struct icmp) ) if ( PayloadLen() < sizeof(struct icmp) )
break; break;
@ -491,8 +422,7 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
case IPPROTO_ICMPV6: case IPPROTO_ICMPV6: {
{
if ( PayloadLen() < sizeof(struct icmp6_hdr) ) if ( PayloadLen() < sizeof(struct icmp6_hdr) )
break; break;
@ -505,57 +435,47 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
default: default: {
{
// This is not a protocol we understand. // This is not a protocol we understand.
break; break;
} }
} }
return pkt_hdr; return pkt_hdr;
} }
static inline bool isIPv6ExtHeader(uint8_t type) static inline bool isIPv6ExtHeader(uint8_t type) {
{ switch ( type ) {
switch ( type )
{
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT:
case IPPROTO_AH: case IPPROTO_AH:
case IPPROTO_ESP: case IPPROTO_ESP:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: return true;
return true; default: return false;
default:
return false;
}
} }
}
IPv6_Hdr_Chain::~IPv6_Hdr_Chain() IPv6_Hdr_Chain::~IPv6_Hdr_Chain() {
{
for ( size_t i = 0; i < chain.size(); ++i ) for ( size_t i = 0; i < chain.size(); ++i )
delete chain[i]; delete chain[i];
delete homeAddr; delete homeAddr;
delete finalDst; delete finalDst;
} }
void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, uint16_t next) {
uint16_t next)
{
length = 0; length = 0;
uint8_t current_type, next_type; uint8_t current_type, next_type;
next_type = IPPROTO_IPV6; next_type = IPPROTO_IPV6;
const u_char* hdrs = (const u_char*)ip6; const u_char* hdrs = (const u_char*)ip6;
if ( total_len < (int)sizeof(struct ip6_hdr) ) if ( total_len < (int)sizeof(struct ip6_hdr) ) {
{
reporter->InternalWarning("truncated IP header in IPv6_HdrChain::Init"); reporter->InternalWarning("truncated IP header in IPv6_HdrChain::Init");
return; return;
} }
do do {
{
// We can't determine a given header's length if there's less than // We can't determine a given header's length if there's less than
// two bytes of data available (2nd byte of extension headers is length) // two bytes of data available (2nd byte of extension headers is length)
if ( total_len < 2 ) if ( total_len < 2 )
@ -568,14 +488,12 @@ void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool se
uint16_t cur_len = p->Length(); uint16_t cur_len = p->Length();
// If this header is truncated, don't add it to chain, don't go further. // If this header is truncated, don't add it to chain, don't go further.
if ( cur_len > total_len ) if ( cur_len > total_len ) {
{
delete p; delete p;
return; return;
} }
if ( set_next && next_type == IPPROTO_FRAGMENT ) if ( set_next && next_type == IPPROTO_FRAGMENT ) {
{
p->ChangeNext(next); p->ChangeNext(next);
next_type = next; next_type = next;
} }
@ -594,53 +512,45 @@ void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool se
length += cur_len; length += cur_len;
total_len -= cur_len; total_len -= cur_len;
} while ( current_type != IPPROTO_FRAGMENT && current_type != IPPROTO_ESP && } while ( current_type != IPPROTO_FRAGMENT && current_type != IPPROTO_ESP && current_type != IPPROTO_MOBILITY &&
current_type != IPPROTO_MOBILITY && isIPv6ExtHeader(next_type) ); isIPv6ExtHeader(next_type) );
} }
bool IPv6_Hdr_Chain::IsFragment() const bool IPv6_Hdr_Chain::IsFragment() const {
{ if ( chain.empty() ) {
if ( chain.empty() )
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
return false; return false;
} }
return chain[chain.size() - 1]->Type() == IPPROTO_FRAGMENT; return chain[chain.size() - 1]->Type() == IPPROTO_FRAGMENT;
} }
IPAddr IPv6_Hdr_Chain::SrcAddr() const IPAddr IPv6_Hdr_Chain::SrcAddr() const {
{
if ( homeAddr ) if ( homeAddr )
return {*homeAddr}; return {*homeAddr};
if ( chain.empty() ) if ( chain.empty() ) {
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
return {}; return {};
} }
return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_src}; return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_src};
} }
IPAddr IPv6_Hdr_Chain::DstAddr() const IPAddr IPv6_Hdr_Chain::DstAddr() const {
{
if ( finalDst ) if ( finalDst )
return {*finalDst}; return {*finalDst};
if ( chain.empty() ) if ( chain.empty() ) {
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
return {}; return {};
} }
return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_dst}; return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_dst};
} }
void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len) void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len) {
{ if ( finalDst ) {
if ( finalDst )
{
// RFC 2460 section 4.1 says Routing should occur at most once. // RFC 2460 section 4.1 says Routing should occur at most once.
reporter->Weird(SrcAddr(), DstAddr(), "multiple_routing_headers"); reporter->Weird(SrcAddr(), DstAddr(), "multiple_routing_headers");
return; return;
@ -649,12 +559,10 @@ void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t le
// Last 16 bytes of header (for all known types) is the address we want. // Last 16 bytes of header (for all known types) is the address we want.
const in6_addr* addr = (const in6_addr*)(((const u_char*)r) + len - 16); const in6_addr* addr = (const in6_addr*)(((const u_char*)r) + len - 16);
switch ( r->ip6r_type ) switch ( r->ip6r_type ) {
{
case 0: // Defined by RFC 2460, deprecated by RFC 5095 case 0: // Defined by RFC 2460, deprecated by RFC 5095
{ {
if ( r->ip6r_segleft > 0 && r->ip6r_len >= 2 ) if ( r->ip6r_segleft > 0 && r->ip6r_len >= 2 ) {
{
if ( r->ip6r_len % 2 == 0 ) if ( r->ip6r_len % 2 == 0 )
finalDst = new IPAddr(*addr); finalDst = new IPAddr(*addr);
else else
@ -663,30 +571,23 @@ void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t le
// Always raise a weird since this type is deprecated. // Always raise a weird since this type is deprecated.
reporter->Weird(SrcAddr(), DstAddr(), "routing0_hdr"); reporter->Weird(SrcAddr(), DstAddr(), "routing0_hdr");
} } break;
break;
case 2: // Defined by Mobile IPv6 RFC 6275. case 2: // Defined by Mobile IPv6 RFC 6275.
{ {
if ( r->ip6r_segleft > 0 ) if ( r->ip6r_segleft > 0 ) {
{
if ( r->ip6r_len == 2 ) if ( r->ip6r_len == 2 )
finalDst = new IPAddr(*addr); finalDst = new IPAddr(*addr);
else else
reporter->Weird(SrcAddr(), DstAddr(), "bad_routing2_len"); reporter->Weird(SrcAddr(), DstAddr(), "bad_routing2_len");
} }
} } break;
break;
default: default: reporter->Weird(SrcAddr(), DstAddr(), "unknown_routing_type", util::fmt("%d", r->ip6r_type)); break;
reporter->Weird(SrcAddr(), DstAddr(), "unknown_routing_type",
util::fmt("%d", r->ip6r_type));
break;
}
} }
}
void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len) void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len) {
{
// Skip two bytes to get the beginning of the first option structure. These // Skip two bytes to get the beginning of the first option structure. These
// two bytes are the protocol for the next header and extension header length, // two bytes are the protocol for the next header and extension header length,
// already known to exist before calling this method. See header format: // already known to exist before calling this method. See header format:
@ -697,39 +598,32 @@ void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len)
len -= 2 * sizeof(uint8_t); len -= 2 * sizeof(uint8_t);
data += 2 * sizeof(uint8_t); data += 2 * sizeof(uint8_t);
while ( len > 0 ) while ( len > 0 ) {
{
const struct ip6_opt* opt = (const struct ip6_opt*)data; const struct ip6_opt* opt = (const struct ip6_opt*)data;
switch ( opt->ip6o_type ) switch ( opt->ip6o_type ) {
{
case 0: case 0:
// If option type is zero, it's a Pad0 and can be just a single // If option type is zero, it's a Pad0 and can be just a single
// byte in width. Skip over it. // byte in width. Skip over it.
data += sizeof(uint8_t); data += sizeof(uint8_t);
len -= sizeof(uint8_t); len -= sizeof(uint8_t);
break; break;
default: default: {
{
// Double-check that the len can hold the whole option structure. // Double-check that the len can hold the whole option structure.
// Otherwise we get a buffer-overflow when we check the option_len. // Otherwise we get a buffer-overflow when we check the option_len.
// Also check that it holds everything for the option itself. // Also check that it holds everything for the option itself.
if ( len < sizeof(struct ip6_opt) || len < sizeof(struct ip6_opt) + opt->ip6o_len ) if ( len < sizeof(struct ip6_opt) || len < sizeof(struct ip6_opt) + opt->ip6o_len ) {
{
reporter->Weird(SrcAddr(), DstAddr(), "bad_ipv6_dest_opt_len"); reporter->Weird(SrcAddr(), DstAddr(), "bad_ipv6_dest_opt_len");
len = 0; len = 0;
break; break;
} }
if ( opt->ip6o_type == if ( opt->ip6o_type == 201 ) // Home Address Option, Mobile IPv6 RFC 6275 section 6.3
201 ) // Home Address Option, Mobile IPv6 RFC 6275 section 6.3
{
if ( opt->ip6o_len == sizeof(struct in6_addr) )
{ {
if ( opt->ip6o_len == sizeof(struct in6_addr) ) {
if ( homeAddr ) if ( homeAddr )
reporter->Weird(SrcAddr(), DstAddr(), "multiple_home_addr_opts"); reporter->Weird(SrcAddr(), DstAddr(), "multiple_home_addr_opts");
else else
homeAddr = new IPAddr( homeAddr = new IPAddr(*((const in6_addr*)(data + sizeof(struct ip6_opt))));
*((const in6_addr*)(data + sizeof(struct ip6_opt))));
} }
else else
reporter->Weird(SrcAddr(), DstAddr(), "bad_home_addr_len"); reporter->Weird(SrcAddr(), DstAddr(), "bad_home_addr_len");
@ -737,14 +631,12 @@ void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len)
data += sizeof(struct ip6_opt) + opt->ip6o_len; data += sizeof(struct ip6_opt) + opt->ip6o_len;
len -= sizeof(struct ip6_opt) + opt->ip6o_len; len -= sizeof(struct ip6_opt) + opt->ip6o_len;
} } break;
break;
}
} }
} }
}
VectorValPtr IPv6_Hdr_Chain::ToVal() const VectorValPtr IPv6_Hdr_Chain::ToVal() const {
{
static auto ip6_ext_hdr_type = id::find_type<RecordType>("ip6_ext_hdr"); static auto ip6_ext_hdr_type = id::find_type<RecordType>("ip6_ext_hdr");
static auto ip6_hopopts_type = id::find_type<RecordType>("ip6_hopopts"); static auto ip6_hopopts_type = id::find_type<RecordType>("ip6_hopopts");
static auto ip6_dstopts_type = id::find_type<RecordType>("ip6_dstopts"); static auto ip6_dstopts_type = id::find_type<RecordType>("ip6_dstopts");
@ -755,53 +647,33 @@ VectorValPtr IPv6_Hdr_Chain::ToVal() const
static auto ip6_ext_hdr_chain_type = id::find_type<VectorType>("ip6_ext_hdr_chain"); static auto ip6_ext_hdr_chain_type = id::find_type<VectorType>("ip6_ext_hdr_chain");
auto rval = make_intrusive<VectorVal>(ip6_ext_hdr_chain_type); auto rval = make_intrusive<VectorVal>(ip6_ext_hdr_chain_type);
for ( size_t i = 1; i < chain.size(); ++i ) for ( size_t i = 1; i < chain.size(); ++i ) {
{
auto v = chain[i]->ToVal(); auto v = chain[i]->ToVal();
auto ext_hdr = make_intrusive<RecordVal>(ip6_ext_hdr_type); auto ext_hdr = make_intrusive<RecordVal>(ip6_ext_hdr_type);
uint8_t type = chain[i]->Type(); uint8_t type = chain[i]->Type();
ext_hdr->Assign(0, type); ext_hdr->Assign(0, type);
switch ( type ) switch ( type ) {
{ case IPPROTO_HOPOPTS: ext_hdr->Assign(1, std::move(v)); break;
case IPPROTO_HOPOPTS: case IPPROTO_DSTOPTS: ext_hdr->Assign(2, std::move(v)); break;
ext_hdr->Assign(1, std::move(v)); case IPPROTO_ROUTING: ext_hdr->Assign(3, std::move(v)); break;
break; case IPPROTO_FRAGMENT: ext_hdr->Assign(4, std::move(v)); break;
case IPPROTO_DSTOPTS: case IPPROTO_AH: ext_hdr->Assign(5, std::move(v)); break;
ext_hdr->Assign(2, std::move(v)); case IPPROTO_ESP: ext_hdr->Assign(6, std::move(v)); break;
break; case IPPROTO_MOBILITY: ext_hdr->Assign(7, std::move(v)); break;
case IPPROTO_ROUTING: default: reporter->InternalWarning("IPv6_Hdr_Chain bad header %d", type); continue;
ext_hdr->Assign(3, std::move(v));
break;
case IPPROTO_FRAGMENT:
ext_hdr->Assign(4, std::move(v));
break;
case IPPROTO_AH:
ext_hdr->Assign(5, std::move(v));
break;
case IPPROTO_ESP:
ext_hdr->Assign(6, std::move(v));
break;
case IPPROTO_MOBILITY:
ext_hdr->Assign(7, std::move(v));
break;
default:
reporter->InternalWarning("IPv6_Hdr_Chain bad header %d", type);
continue;
} }
rval->Assign(rval->Size(), std::move(ext_hdr)); rval->Assign(rval->Size(), std::move(ext_hdr));
} }
return rval; return rval;
} }
IP_Hdr* IP_Hdr::Copy() const IP_Hdr* IP_Hdr::Copy() const {
{
char* new_hdr = new char[HdrLen()]; char* new_hdr = new char[HdrLen()];
if ( ip4 ) if ( ip4 ) {
{
memcpy(new_hdr, ip4, HdrLen()); memcpy(new_hdr, ip4, HdrLen());
return new IP_Hdr((const struct ip*)new_hdr, true); return new IP_Hdr((const struct ip*)new_hdr, true);
} }
@ -810,10 +682,9 @@ IP_Hdr* IP_Hdr::Copy() const
const struct ip6_hdr* new_ip6 = (const struct ip6_hdr*)new_hdr; const struct ip6_hdr* new_ip6 = (const struct ip6_hdr*)new_hdr;
IPv6_Hdr_Chain* new_ip6_hdrs = ip6_hdrs->Copy(new_ip6); IPv6_Hdr_Chain* new_ip6_hdrs = ip6_hdrs->Copy(new_ip6);
return new IP_Hdr(new_ip6, true, 0, new_ip6_hdrs); return new IP_Hdr(new_ip6, true, 0, new_ip6_hdrs);
} }
IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const {
{
IPv6_Hdr_Chain* rval = new IPv6_Hdr_Chain; IPv6_Hdr_Chain* rval = new IPv6_Hdr_Chain;
rval->length = length; rval->length = length;
@ -823,8 +694,7 @@ IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const
if ( finalDst ) if ( finalDst )
rval->finalDst = new IPAddr(*finalDst); rval->finalDst = new IPAddr(*finalDst);
if ( chain.empty() ) if ( chain.empty() ) {
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
delete rval; delete rval;
return nullptr; return nullptr;
@ -833,13 +703,12 @@ IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const
const u_char* new_data = (const u_char*)new_hdr; const u_char* new_data = (const u_char*)new_hdr;
const u_char* old_data = chain[0]->Data(); const u_char* old_data = chain[0]->Data();
for ( size_t i = 0; i < chain.size(); ++i ) for ( size_t i = 0; i < chain.size(); ++i ) {
{
int off = chain[i]->Data() - old_data; int off = chain[i]->Data() - old_data;
rval->chain.push_back(new IPv6_Hdr(chain[i]->Type(), new_data + off)); rval->chain.push_back(new IPv6_Hdr(chain[i]->Type(), new_data + off));
} }
return rval; return rval;
} }
} // namespace zeek } // namespace zeek

157
src/IP.h
View file

@ -20,8 +20,7 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class IPAddr; class IPAddr;
class RecordVal; class RecordVal;
@ -29,106 +28,83 @@ class VectorVal;
using RecordValPtr = IntrusivePtr<RecordVal>; using RecordValPtr = IntrusivePtr<RecordVal>;
using VectorValPtr = IntrusivePtr<VectorVal>; using VectorValPtr = IntrusivePtr<VectorVal>;
namespace detail namespace detail {
{
class FragReassembler; class FragReassembler;
} }
#ifndef IPPROTO_MOBILITY #ifndef IPPROTO_MOBILITY
#define IPPROTO_MOBILITY 135 #define IPPROTO_MOBILITY 135
#endif #endif
struct ip6_mobility struct ip6_mobility {
{
uint8_t ip6mob_payload; uint8_t ip6mob_payload;
uint8_t ip6mob_len; uint8_t ip6mob_len;
uint8_t ip6mob_type; uint8_t ip6mob_type;
uint8_t ip6mob_rsv; uint8_t ip6mob_rsv;
uint16_t ip6mob_chksum; uint16_t ip6mob_chksum;
}; };
/** /**
* Base class for IPv6 header/extensions. * Base class for IPv6 header/extensions.
*/ */
class IPv6_Hdr class IPv6_Hdr {
{
public: public:
/** /**
* Construct an IPv6 header or extension header from assigned type number. * Construct an IPv6 header or extension header from assigned type number.
*/ */
IPv6_Hdr(uint8_t t, const u_char* d) : type(t), data(d) { } IPv6_Hdr(uint8_t t, const u_char* d) : type(t), data(d) {}
/** /**
* Replace the value of the next protocol field. * Replace the value of the next protocol field.
*/ */
void ChangeNext(uint8_t next_type) void ChangeNext(uint8_t next_type) {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: ((ip6_hdr*)data)->ip6_nxt = next_type; break;
{
case IPPROTO_IPV6:
((ip6_hdr*)data)->ip6_nxt = next_type;
break;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT:
case IPPROTO_AH: case IPPROTO_AH:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: ((ip6_ext*)data)->ip6e_nxt = next_type; break;
((ip6_ext*)data)->ip6e_nxt = next_type;
break;
case IPPROTO_ESP: case IPPROTO_ESP:
default: default: break;
break;
} }
} }
~IPv6_Hdr() { } ~IPv6_Hdr() {}
/** /**
* Returns the assigned IPv6 extension header type number of the header * Returns the assigned IPv6 extension header type number of the header
* that immediately follows this one. * that immediately follows this one.
*/ */
uint8_t NextHdr() const uint8_t NextHdr() const {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: return ((ip6_hdr*)data)->ip6_nxt;
{
case IPPROTO_IPV6:
return ((ip6_hdr*)data)->ip6_nxt;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT:
case IPPROTO_AH: case IPPROTO_AH:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: return ((ip6_ext*)data)->ip6e_nxt;
return ((ip6_ext*)data)->ip6e_nxt;
case IPPROTO_ESP: case IPPROTO_ESP:
default: default: return IPPROTO_NONE;
return IPPROTO_NONE;
} }
} }
/** /**
* Returns the length of the header in bytes. * Returns the length of the header in bytes.
*/ */
uint16_t Length() const uint16_t Length() const {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: return 40;
{
case IPPROTO_IPV6:
return 40;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: return 8 + 8 * ((ip6_ext*)data)->ip6e_len;
return 8 + 8 * ((ip6_ext*)data)->ip6e_len; case IPPROTO_FRAGMENT: return 8;
case IPPROTO_FRAGMENT: case IPPROTO_AH: return 8 + 4 * ((ip6_ext*)data)->ip6e_len;
return 8; case IPPROTO_ESP: return 8; // encrypted payload begins after 8 bytes
case IPPROTO_AH: default: return 0;
return 8 + 4 * ((ip6_ext*)data)->ip6e_len;
case IPPROTO_ESP:
return 8; // encrypted payload begins after 8 bytes
default:
return 0;
} }
} }
@ -154,10 +130,9 @@ protected:
private: private:
bool IsOptionTruncated(uint16_t off) const; bool IsOptionTruncated(uint16_t off) const;
}; };
class IPv6_Hdr_Chain class IPv6_Hdr_Chain {
{
public: public:
/** /**
* Initializes the header chain from an IPv6 header structure. * Initializes the header chain from an IPv6 header structure.
@ -195,8 +170,7 @@ public:
/** /**
* Returns pointer to fragment header structure if the chain contains one. * Returns pointer to fragment header structure if the chain contains one.
*/ */
const struct ip6_frag* GetFragHdr() const const struct ip6_frag* GetFragHdr() const {
{
return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr; return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr;
} }
@ -204,10 +178,7 @@ public:
* If the header chain is a fragment, returns the offset in number of bytes * If the header chain is a fragment, returns the offset in number of bytes
* relative to the start of the Fragmentable Part of the original packet. * relative to the start of the Fragmentable Part of the original packet.
*/ */
uint16_t FragOffset() const uint16_t FragOffset() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0; }
{
return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0;
}
/** /**
* If the header chain is a fragment, returns the identification field. * If the header chain is a fragment, returns the identification field.
@ -250,10 +221,7 @@ protected:
* Initializes the header chain from an IPv6 header structure, and replaces * Initializes the header chain from an IPv6 header structure, and replaces
* the first next protocol pointer field that points to a fragment header. * the first next protocol pointer field that points to a fragment header.
*/ */
IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) { Init(ip6, len, true, next); }
{
Init(ip6, len, true, next);
}
/** /**
* Initializes the header chain from an IPv6 header structure of a given * Initializes the header chain from an IPv6 header structure of a given
@ -291,14 +259,13 @@ protected:
* non-zero segments left. * non-zero segments left.
*/ */
IPAddr* finalDst = nullptr; IPAddr* finalDst = nullptr;
}; };
/** /**
* A class that wraps either an IPv4 or IPv6 packet and abstracts methods * A class that wraps either an IPv4 or IPv6 packet and abstracts methods
* for inquiring about common features between the two. * for inquiring about common features between the two.
*/ */
class IP_Hdr class IP_Hdr {
{
public: public:
/** /**
* Construct the header wrapper from an IPv4 packet. Caller must have * Construct the header wrapper from an IPv4 packet. Caller must have
@ -308,9 +275,7 @@ public:
* @param reassembled whether this header is for a reassembled packet. * @param reassembled whether this header is for a reassembled packet.
*/ */
IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false) IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false)
: ip4(arg_ip4), del(arg_del), reassembled(reassembled) : ip4(arg_ip4), del(arg_del), reassembled(reassembled) {}
{
}
/** /**
* Construct the header wrapper from an IPv6 packet. Caller must have * Construct the header wrapper from an IPv6 packet. Caller must have
@ -324,12 +289,9 @@ public:
* @param c an already-constructed header chain to take ownership of. * @param c an already-constructed header chain to take ownership of.
* @param reassembled whether this header is for a reassembled packet. * @param reassembled whether this header is for a reassembled packet.
*/ */
IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, const IPv6_Hdr_Chain* c = nullptr,
const IPv6_Hdr_Chain* c = nullptr, bool reassembled = false) bool reassembled = false)
: ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), : ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), reassembled(reassembled) {}
reassembled(reassembled)
{
}
/** /**
* Copy a header. The internal buffer which contains the header data * Copy a header. The internal buffer which contains the header data
@ -341,14 +303,12 @@ public:
/** /**
* Destructor. * Destructor.
*/ */
~IP_Hdr() ~IP_Hdr() {
{
delete ip6_hdrs; delete ip6_hdrs;
if ( del ) if ( del ) {
{ delete[] (struct ip*)ip4;
delete[](struct ip*) ip4; delete[] (struct ip6_hdr*)ip6;
delete[](struct ip6_hdr*) ip6;
} }
} }
@ -391,8 +351,7 @@ public:
* Returns a pointer to the payload of the IP packet, usually an * Returns a pointer to the payload of the IP packet, usually an
* upper-layer protocol. * upper-layer protocol.
*/ */
const u_char* Payload() const const u_char* Payload() const {
{
if ( ip4 ) if ( ip4 )
return ((const u_char*)ip4) + ip4->ip_hl * 4; return ((const u_char*)ip4) + ip4->ip_hl * 4;
@ -403,8 +362,7 @@ public:
* Returns a pointer to the mobility header of the IP packet, if present, * Returns a pointer to the mobility header of the IP packet, if present,
* else a null pointer. * else a null pointer.
*/ */
const ip6_mobility* MobilityHeader() const const ip6_mobility* MobilityHeader() const {
{
if ( ip4 ) if ( ip4 )
return nullptr; return nullptr;
else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY ) else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY )
@ -420,10 +378,8 @@ public:
* Also returns 0 if the IPv4 length field is set to zero - which is, e.g., * Also returns 0 if the IPv4 length field is set to zero - which is, e.g.,
* the case when TCP segment offloading is enabled. * the case when TCP segment offloading is enabled.
*/ */
uint16_t PayloadLen() const uint16_t PayloadLen() const {
{ if ( ip4 ) {
if ( ip4 )
{
// prevent overflow in case of segment offloading/zeroed header length. // prevent overflow in case of segment offloading/zeroed header length.
auto total_len = ntohs(ip4->ip_len); auto total_len = ntohs(ip4->ip_len);
return total_len ? total_len - ip4->ip_hl * 4 : 0; return total_len ? total_len - ip4->ip_hl * 4 : 0;
@ -435,8 +391,7 @@ public:
/** /**
* Returns the length of the IP packet (length of headers and payload). * Returns the length of the IP packet (length of headers and payload).
*/ */
uint32_t TotalLen() const uint32_t TotalLen() const {
{
if ( ip4 ) if ( ip4 )
return ntohs(ip4->ip_len); return ntohs(ip4->ip_len);
@ -451,8 +406,7 @@ public:
/** /**
* For IPv6 header chains, returns the type of the last header in the chain. * For IPv6 header chains, returns the type of the last header in the chain.
*/ */
uint8_t LastHeader() const uint8_t LastHeader() const {
{
if ( ip4 ) if ( ip4 )
return IPPROTO_RAW; return IPPROTO_RAW;
@ -468,8 +422,7 @@ public:
* upper-layer protocol. For IPv6, this returns the last (extension) * upper-layer protocol. For IPv6, this returns the last (extension)
* header's Next Header value. * header's Next Header value.
*/ */
unsigned char NextProto() const unsigned char NextProto() const {
{
if ( ip4 ) if ( ip4 )
return ip4->ip_p; return ip4->ip_p;
@ -488,19 +441,13 @@ public:
/** /**
* Returns whether the IP header indicates this packet is a fragment. * Returns whether the IP header indicates this packet is a fragment.
*/ */
bool IsFragment() const bool IsFragment() const { return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment(); }
{
return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment();
}
/** /**
* Returns the fragment packet's offset in relation to the original * Returns the fragment packet's offset in relation to the original
* packet in bytes. * packet in bytes.
*/ */
uint16_t FragOffset() const uint16_t FragOffset() const { return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset(); }
{
return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset();
}
/** /**
* Returns the fragment packet's identification field. * Returns the fragment packet's identification field.
@ -553,6 +500,6 @@ private:
const IPv6_Hdr_Chain* ip6_hdrs = nullptr; const IPv6_Hdr_Chain* ip6_hdrs = nullptr;
bool del = false; bool del = false;
bool reassembled = false; bool reassembled = false;
}; };
} // namespace zeek } // namespace zeek

View file

@ -13,28 +13,23 @@
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
#include "zeek/analyzer/Manager.h" #include "zeek/analyzer/Manager.h"
namespace zeek namespace zeek {
{
const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{}); const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{});
const IPAddr IPAddr::v6_unspecified = IPAddr(); const IPAddr IPAddr::v6_unspecified = IPAddr();
namespace detail namespace detail {
{
ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
TransportProto t, bool one_way) bool one_way) {
{
Init(src, dst, src_port, dst_port, t, one_way); Init(src, dst, src_port, dst_port, t, one_way);
} }
ConnKey::ConnKey(const ConnTuple& id) ConnKey::ConnKey(const ConnTuple& id) {
{
Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way); Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way);
} }
ConnKey& ConnKey::operator=(const ConnKey& rhs) ConnKey& ConnKey::operator=(const ConnKey& rhs) {
{
if ( this == &rhs ) if ( this == &rhs )
return *this; return *this;
@ -51,13 +46,11 @@ ConnKey& ConnKey::operator=(const ConnKey& rhs)
valid = rhs.valid; valid = rhs.valid;
return *this; return *this;
} }
ConnKey::ConnKey(Val* v) ConnKey::ConnKey(Val* v) {
{
const auto& vt = v->GetType(); const auto& vt = v->GetType();
if ( ! IsRecord(vt->Tag()) ) if ( ! IsRecord(vt->Tag()) ) {
{
valid = false; valid = false;
return; return;
} }
@ -68,23 +61,20 @@ ConnKey::ConnKey(Val* v)
int orig_h, orig_p; // indices into record's value list int orig_h, orig_p; // indices into record's value list
int resp_h, resp_p; int resp_h, resp_p;
if ( vr == id::conn_id ) if ( vr == id::conn_id ) {
{
orig_h = 0; orig_h = 0;
orig_p = 1; orig_p = 1;
resp_h = 2; resp_h = 2;
resp_p = 3; resp_p = 3;
} }
else else {
{
// While it's not a conn_id, it may have equivalent fields. // While it's not a conn_id, it may have equivalent fields.
orig_h = vr->FieldOffset("orig_h"); orig_h = vr->FieldOffset("orig_h");
resp_h = vr->FieldOffset("resp_h"); resp_h = vr->FieldOffset("resp_h");
orig_p = vr->FieldOffset("orig_p"); orig_p = vr->FieldOffset("orig_p");
resp_p = vr->FieldOffset("resp_p"); resp_p = vr->FieldOffset("resp_p");
if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 ) if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 ) {
{
valid = false; valid = false;
return; return;
} }
@ -99,13 +89,12 @@ ConnKey::ConnKey(Val* v)
auto orig_portv = vl->GetFieldAs<PortVal>(orig_p); auto orig_portv = vl->GetFieldAs<PortVal>(orig_p);
auto resp_portv = vl->GetFieldAs<PortVal>(resp_p); auto resp_portv = vl->GetFieldAs<PortVal>(resp_p);
Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), htons((unsigned short)resp_portv->Port()),
htons((unsigned short)resp_portv->Port()), orig_portv->PortType(), false); orig_portv->PortType(), false);
} }
void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
TransportProto t, bool one_way) bool one_way) {
{
// Because of padding in the object, this needs to memset to clear out // Because of padding in the object, this needs to memset to clear out
// the extra memory used by padding. Otherwise, the session key stuff // the extra memory used by padding. Otherwise, the session key stuff
// doesn't work quite right. // doesn't work quite right.
@ -114,15 +103,13 @@ void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint
// Lookup up connection based on canonical ordering, which is // Lookup up connection based on canonical ordering, which is
// the smaller of <src addr, src port> and <dst addr, dst port> // the smaller of <src addr, src port> and <dst addr, dst port>
// followed by the other. // followed by the other.
if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) {
{
ip1 = src.in6; ip1 = src.in6;
ip2 = dst.in6; ip2 = dst.in6;
port1 = src_port; port1 = src_port;
port2 = dst_port; port2 = dst_port;
} }
else else {
{
ip1 = dst.in6; ip1 = dst.in6;
ip2 = src.in6; ip2 = src.in6;
port1 = dst_port; port1 = dst_port;
@ -131,32 +118,25 @@ void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint
transport = t; transport = t;
valid = true; valid = true;
} }
} // namespace detail } // namespace detail
IPAddr::IPAddr(const String& s) IPAddr::IPAddr(const String& s) { Init(s.CheckString()); }
{
Init(s.CheckString());
}
std::unique_ptr<detail::HashKey> IPAddr::MakeHashKey() const std::unique_ptr<detail::HashKey> IPAddr::MakeHashKey() const {
{
return std::make_unique<detail::HashKey>((void*)in6.s6_addr, sizeof(in6.s6_addr)); return std::make_unique<detail::HashKey>((void*)in6.s6_addr, sizeof(in6.s6_addr));
} }
static inline uint32_t bit_mask32(int bottom_bits) static inline uint32_t bit_mask32(int bottom_bits) {
{
if ( bottom_bits >= 32 ) if ( bottom_bits >= 32 )
return 0xffffffff; return 0xffffffff;
return (((uint32_t)1) << bottom_bits) - 1; return (((uint32_t)1) << bottom_bits) - 1;
} }
void IPAddr::Mask(int top_bits_to_keep) void IPAddr::Mask(int top_bits_to_keep) {
{ if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 ) {
if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 )
{
reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep); reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep);
return; return;
} }
@ -174,12 +154,10 @@ void IPAddr::Mask(int top_bits_to_keep)
for ( unsigned int i = 0; i < 4; ++i ) for ( unsigned int i = 0; i < 4; ++i )
p[i] &= mask_bits[i]; p[i] &= mask_bits[i];
} }
void IPAddr::ReverseMask(int top_bits_to_chop) void IPAddr::ReverseMask(int top_bits_to_chop) {
{ if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 ) {
if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 )
{
reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop); reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop);
return; return;
} }
@ -197,10 +175,9 @@ void IPAddr::ReverseMask(int top_bits_to_chop)
for ( unsigned int i = 0; i < 4; ++i ) for ( unsigned int i = 0; i < 4; ++i )
p[i] &= mask_bits[i]; p[i] &= mask_bits[i];
} }
bool IPAddr::ConvertString(const char* s, in6_addr* result) bool IPAddr::ConvertString(const char* s, in6_addr* result) {
{
for ( auto p = s; *p; ++p ) for ( auto p = s; *p; ++p )
if ( *p == ':' ) if ( *p == ':' )
// IPv6 // IPv6
@ -229,21 +206,17 @@ bool IPAddr::ConvertString(const char* s, in6_addr* result)
memcpy(result->s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); memcpy(result->s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix));
memcpy(&result->s6_addr[12], &addr, sizeof(uint32_t)); memcpy(&result->s6_addr[12], &addr, sizeof(uint32_t));
return true; return true;
} }
void IPAddr::Init(const char* s) void IPAddr::Init(const char* s) {
{ if ( ! ConvertString(s, &in6) ) {
if ( ! ConvertString(s, &in6) )
{
reporter->Error("Bad IP address: %s", s); reporter->Error("Bad IP address: %s", s);
memset(in6.s6_addr, 0, sizeof(in6.s6_addr)); memset(in6.s6_addr, 0, sizeof(in6.s6_addr));
} }
} }
std::string IPAddr::AsString() const std::string IPAddr::AsString() const {
{ if ( GetFamily() == IPv4 ) {
if ( GetFamily() == IPv4 )
{
char s[INET_ADDRSTRLEN]; char s[INET_ADDRSTRLEN];
if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) ) if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) )
@ -251,8 +224,7 @@ std::string IPAddr::AsString() const
else else
return s; return s;
} }
else else {
{
char s[INET6_ADDRSTRLEN]; char s[INET6_ADDRSTRLEN];
if ( ! zeek_inet_ntop(AF_INET6, in6.s6_addr, s, INET6_ADDRSTRLEN) ) if ( ! zeek_inet_ntop(AF_INET6, in6.s6_addr, s, INET6_ADDRSTRLEN) )
@ -260,31 +232,26 @@ std::string IPAddr::AsString() const
else else
return s; return s;
} }
} }
std::string IPAddr::AsHexString() const std::string IPAddr::AsHexString() const {
{
char buf[33]; char buf[33];
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 ) {
{
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; uint32_t* p = (uint32_t*)&in6.s6_addr[12];
snprintf(buf, sizeof(buf), "%08x", (uint32_t)ntohl(*p)); snprintf(buf, sizeof(buf), "%08x", (uint32_t)ntohl(*p));
} }
else else {
{
uint32_t* p = (uint32_t*)in6.s6_addr; uint32_t* p = (uint32_t*)in6.s6_addr;
snprintf(buf, sizeof(buf), "%08x%08x%08x%08x", (uint32_t)ntohl(p[0]), (uint32_t)ntohl(p[1]), snprintf(buf, sizeof(buf), "%08x%08x%08x%08x", (uint32_t)ntohl(p[0]), (uint32_t)ntohl(p[1]),
(uint32_t)ntohl(p[2]), (uint32_t)ntohl(p[3])); (uint32_t)ntohl(p[2]), (uint32_t)ntohl(p[3]));
} }
return buf; return buf;
} }
std::string IPAddr::PtrName() const std::string IPAddr::PtrName() const {
{ if ( GetFamily() == IPv4 ) {
if ( GetFamily() == IPv4 )
{
char buf[256]; char buf[256];
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; uint32_t* p = (uint32_t*)&in6.s6_addr[12];
uint32_t a = ntohl(*p); uint32_t a = ntohl(*p);
@ -295,17 +262,14 @@ std::string IPAddr::PtrName() const
snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3); snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3);
return buf; return buf;
} }
else else {
{
static const char hex_digit[] = "0123456789abcdef"; static const char hex_digit[] = "0123456789abcdef";
std::string ptr_name("ip6.arpa"); std::string ptr_name("ip6.arpa");
uint32_t* p = (uint32_t*)in6.s6_addr; uint32_t* p = (uint32_t*)in6.s6_addr;
for ( unsigned int i = 0; i < 4; ++i ) for ( unsigned int i = 0; i < 4; ++i ) {
{
uint32_t a = ntohl(p[i]); uint32_t a = ntohl(p[i]);
for ( unsigned int j = 1; j <= 8; ++j ) for ( unsigned int j = 1; j <= 8; ++j ) {
{
ptr_name.insert(0, 1, '.'); ptr_name.insert(0, 1, '.');
ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]); ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]);
} }
@ -313,68 +277,57 @@ std::string IPAddr::PtrName() const
return ptr_name; return ptr_name;
} }
} }
IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) {
{ if ( length > 32 ) {
if ( length > 32 )
{
reporter->Error("Bad in4_addr IPPrefix length : %d", length); reporter->Error("Bad in4_addr IPPrefix length : %d", length);
this->length = 0; this->length = 0;
} }
prefix.Mask(this->length); prefix.Mask(this->length);
} }
IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) {
{ if ( length > 128 ) {
if ( length > 128 )
{
reporter->Error("Bad in6_addr IPPrefix length : %d", length); reporter->Error("Bad in6_addr IPPrefix length : %d", length);
this->length = 0; this->length = 0;
} }
prefix.Mask(this->length); prefix.Mask(this->length);
} }
bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const {
{ if ( GetFamily() == IPv4 && ! len_is_v6_relative ) {
if ( GetFamily() == IPv4 && ! len_is_v6_relative )
{
if ( length > 32 ) if ( length > 32 )
return false; return false;
} }
else else {
{
if ( length > 128 ) if ( length > 128 )
return false; return false;
} }
return true; return true;
} }
IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr) IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr) {
{ if ( prefix.CheckPrefixLength(length, len_is_v6_relative) ) {
if ( prefix.CheckPrefixLength(length, len_is_v6_relative) )
{
if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative ) if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative )
this->length = length + 96; this->length = length + 96;
else else
this->length = length; this->length = length;
} }
else else {
{
auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6"; auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6";
reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length); reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length);
this->length = 0; this->length = 0;
} }
prefix.Mask(this->length); prefix.Mask(this->length);
} }
std::string IPPrefix::AsString() const std::string IPPrefix::AsString() const {
{
char l[16]; char l[16];
if ( prefix.GetFamily() == IPv4 ) if ( prefix.GetFamily() == IPv4 )
@ -383,12 +336,10 @@ std::string IPPrefix::AsString() const
modp_uitoa10(length, l); modp_uitoa10(length, l);
return prefix.AsString() + "/" + l; return prefix.AsString() + "/" + l;
} }
std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const {
{ struct {
struct
{
in6_addr ip; in6_addr ip;
uint32_t len; uint32_t len;
} key; } key;
@ -397,10 +348,9 @@ std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const
key.len = Length(); key.len = Length();
return std::make_unique<detail::HashKey>(&key, sizeof(key)); return std::make_unique<detail::HashKey>(&key, sizeof(key));
} }
bool IPPrefix::ConvertString(const char* text, IPPrefix* result) bool IPPrefix::ConvertString(const char* text, IPPrefix* result) {
{
std::string s(text); std::string s(text);
size_t slash_loc = s.find('/'); size_t slash_loc = s.find('/');
@ -422,6 +372,6 @@ bool IPPrefix::ConvertString(const char* text, IPPrefix* result)
*result = IPPrefix(ip, len); *result = IPPrefix(ip, len);
return true; return true;
} }
} // namespace zeek } // namespace zeek

View file

@ -12,20 +12,17 @@
using in4_addr = in_addr; using in4_addr = in_addr;
namespace zeek namespace zeek {
{
class String; class String;
struct ConnTuple; struct ConnTuple;
class Val; class Val;
namespace detail namespace detail {
{
class HashKey; class HashKey;
class ConnKey class ConnKey {
{
public: public:
in6_addr ip1; in6_addr ip1;
in6_addr ip2; in6_addr ip2;
@ -34,8 +31,7 @@ public:
TransportProto transport = TRANSPORT_UNKNOWN; TransportProto transport = TRANSPORT_UNKNOWN;
bool valid = true; bool valid = true;
ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t, bool one_way);
TransportProto t, bool one_way);
ConnKey(const ConnTuple& conn); ConnKey(const ConnTuple& conn);
ConnKey(const ConnKey& rhs) { *this = rhs; } ConnKey(const ConnKey& rhs) { *this = rhs; }
ConnKey(Val* v); ConnKey(Val* v);
@ -50,17 +46,16 @@ public:
ConnKey& operator=(const ConnKey& rhs); ConnKey& operator=(const ConnKey& rhs);
private: private:
void Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, void Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
TransportProto t, bool one_way); bool one_way);
}; };
} // namespace detail } // namespace detail
/** /**
* Class storing both IPv4 and IPv6 addresses. * Class storing both IPv4 and IPv6 addresses.
*/ */
class IPAddr class IPAddr {
{
public: public:
/** /**
* Address family. * Address family.
@ -70,11 +65,7 @@ public:
/** /**
* Byte order. * Byte order.
*/ */
enum ByteOrder enum ByteOrder { Host, Network };
{
Host,
Network
};
/** /**
* Constructs the unspecified IPv6 address (all 128 bits zeroed). * Constructs the unspecified IPv6 address (all 128 bits zeroed).
@ -86,8 +77,7 @@ public:
* *
* @param in6 The IPv6 address. * @param in6 The IPv6 address.
*/ */
explicit IPAddr(const in4_addr& in4) explicit IPAddr(const in4_addr& in4) {
{
memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix));
memcpy(&in6.s6_addr[12], &in4.s_addr, sizeof(in4.s_addr)); memcpy(&in6.s6_addr[12], &in4.s_addr, sizeof(in4.s_addr));
} }
@ -97,7 +87,7 @@ public:
* *
* @param in6 The IPv6 address. * @param in6 The IPv6 address.
*/ */
explicit IPAddr(const in6_addr& arg_in6) : in6(arg_in6) { } explicit IPAddr(const in6_addr& arg_in6) : in6(arg_in6) {}
/** /**
* Constructs an address instance from a string representation. * Constructs an address instance from a string representation.
@ -150,8 +140,7 @@ public:
/** /**
* Returns the address' family. * Returns the address' family.
*/ */
Family GetFamily() const Family GetFamily() const {
{
if ( memcmp(in6.s6_addr, v4_mapped_prefix, 12) == 0 ) if ( memcmp(in6.s6_addr, v4_mapped_prefix, 12) == 0 )
return IPv4; return IPv4;
@ -166,8 +155,7 @@ public:
/** /**
* Returns true if the address represents a multicast address. * Returns true if the address represents a multicast address.
*/ */
bool IsMulticast() const bool IsMulticast() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return in6.s6_addr[12] == 224; return in6.s6_addr[12] == 224;
@ -177,11 +165,10 @@ public:
/** /**
* Returns true if the address represents a broadcast address. * Returns true if the address represents a broadcast address.
*/ */
bool IsBroadcast() const bool IsBroadcast() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return ((in6.s6_addr[12] == 0xff) && (in6.s6_addr[13] == 0xff) && return ((in6.s6_addr[12] == 0xff) && (in6.s6_addr[13] == 0xff) && (in6.s6_addr[14] == 0xff) &&
(in6.s6_addr[14] == 0xff) && (in6.s6_addr[15] == 0xff)); (in6.s6_addr[15] == 0xff));
return false; return false;
} }
@ -198,15 +185,12 @@ public:
* @return The number of 32-bit words the raw representation uses. This * @return The number of 32-bit words the raw representation uses. This
* will be 1 for an IPv4 address and 4 for an IPv6 address. * will be 1 for an IPv4 address and 4 for an IPv6 address.
*/ */
int GetBytes(const uint32_t** bytes) const int GetBytes(const uint32_t** bytes) const {
{ if ( GetFamily() == IPv4 ) {
if ( GetFamily() == IPv4 )
{
*bytes = (uint32_t*)&in6.s6_addr[12]; *bytes = (uint32_t*)&in6.s6_addr[12];
return 1; return 1;
} }
else else {
{
*bytes = (uint32_t*)in6.s6_addr; *bytes = (uint32_t*)in6.s6_addr;
return 4; return 4;
} }
@ -223,12 +207,10 @@ public:
* @param order The byte-order in which the returned raw bytes are copied. * @param order The byte-order in which the returned raw bytes are copied.
* The default is network order. * The default is network order.
*/ */
void CopyIPv6(uint32_t* bytes, ByteOrder order = Network) const void CopyIPv6(uint32_t* bytes, ByteOrder order = Network) const {
{
memcpy(bytes, in6.s6_addr, sizeof(in6.s6_addr)); memcpy(bytes, in6.s6_addr, sizeof(in6.s6_addr));
if ( order == Host ) if ( order == Host ) {
{
for ( unsigned int i = 0; i < 4; ++i ) for ( unsigned int i = 0; i < 4; ++i )
bytes[i] = ntohl(bytes[i]); bytes[i] = ntohl(bytes[i]);
} }
@ -238,10 +220,7 @@ public:
* Retrieves a copy of the IPv6 raw byte representation of the address. * Retrieves a copy of the IPv6 raw byte representation of the address.
* @see CopyIPv6(uint32_t) * @see CopyIPv6(uint32_t)
*/ */
void CopyIPv6(in6_addr* arg_in6) const void CopyIPv6(in6_addr* arg_in6) const { memcpy(arg_in6->s6_addr, in6.s6_addr, sizeof(in6.s6_addr)); }
{
memcpy(arg_in6->s6_addr, in6.s6_addr, sizeof(in6.s6_addr));
}
/** /**
* Retrieves a copy of the IPv4 raw byte representation of the address. * Retrieves a copy of the IPv4 raw byte representation of the address.
@ -251,10 +230,7 @@ public:
* @param in4 The pointer to a memory location in which the raw bytes * @param in4 The pointer to a memory location in which the raw bytes
* of the address are to be copied in network byte-order. * of the address are to be copied in network byte-order.
*/ */
void CopyIPv4(in4_addr* in4) const void CopyIPv4(in4_addr* in4) const { memcpy(&in4->s_addr, &in6.s6_addr[12], sizeof(in4->s_addr)); }
{
memcpy(&in4->s_addr, &in6.s6_addr[12], sizeof(in4->s_addr));
}
/** /**
* Returns a key that can be used to lookup the IP Address in a hash table. * Returns a key that can be used to lookup the IP Address in a hash table.
@ -287,8 +263,7 @@ public:
/** /**
* Assignment operator. * Assignment operator.
*/ */
IPAddr& operator=(const IPAddr& other) IPAddr& operator=(const IPAddr& other) {
{
// No self-assignment check here because it's correct without it and // No self-assignment check here because it's correct without it and
// makes the common case faster. // makes the common case faster.
in6 = other.in6; in6 = other.in6;
@ -299,8 +274,7 @@ public:
* Bitwise OR operator returns the IP address resulting from the bitwise * Bitwise OR operator returns the IP address resulting from the bitwise
* OR operation on the raw bytes of this address with another. * OR operation on the raw bytes of this address with another.
*/ */
IPAddr operator|(const IPAddr& other) IPAddr operator|(const IPAddr& other) {
{
in6_addr result; in6_addr result;
for ( int i = 0; i < 16; ++i ) for ( int i = 0; i < 16; ++i )
result.s6_addr[i] = this->in6.s6_addr[i] | other.in6.s6_addr[i]; result.s6_addr[i] = this->in6.s6_addr[i] | other.in6.s6_addr[i];
@ -320,8 +294,7 @@ public:
* in an URI. For IPv4 addresses, this is the same as AsString(), but * in an URI. For IPv4 addresses, this is the same as AsString(), but
* IPv6 addresses are encased in square brackets. * IPv6 addresses are encased in square brackets.
*/ */
std::string AsURIString() const std::string AsURIString() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return AsString(); return AsString();
@ -348,8 +321,7 @@ public:
/** /**
* Comparison operator for IP address. * Comparison operator for IP address.
*/ */
friend bool operator==(const IPAddr& addr1, const IPAddr& addr2) friend bool operator==(const IPAddr& addr1, const IPAddr& addr2) {
{
return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) == 0; return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) == 0;
} }
@ -360,15 +332,11 @@ public:
* IP addresses. However, the order does not necessarily correspond to * IP addresses. However, the order does not necessarily correspond to
* their numerical values. * their numerical values.
*/ */
friend bool operator<(const IPAddr& addr1, const IPAddr& addr2) friend bool operator<(const IPAddr& addr1, const IPAddr& addr2) {
{
return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) < 0; return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) < 0;
} }
friend bool operator<=(const IPAddr& addr1, const IPAddr& addr2) friend bool operator<=(const IPAddr& addr1, const IPAddr& addr2) { return addr1 < addr2 || addr1 == addr2; }
{
return addr1 < addr2 || addr1 == addr2;
}
friend bool operator>=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 < addr2); } friend bool operator>=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 < addr2); }
@ -412,8 +380,7 @@ public:
* *
* @return whether the string is a valid IP address * @return whether the string is a valid IP address
*/ */
static bool IsValid(const char* s) static bool IsValid(const char* s) {
{
in6_addr tmp; in6_addr tmp;
return ConvertString(s, &tmp); return ConvertString(s, &tmp);
} }
@ -444,73 +411,57 @@ private:
// Top 96 bits of a v4-mapped-addr. // Top 96 bits of a v4-mapped-addr.
static constexpr uint8_t v4_mapped_prefix[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}; static constexpr uint8_t v4_mapped_prefix[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff};
}; };
inline IPAddr::IPAddr(Family family, const uint32_t* bytes, ByteOrder order) inline IPAddr::IPAddr(Family family, const uint32_t* bytes, ByteOrder order) {
{ if ( family == IPv4 ) {
if ( family == IPv4 )
{
memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix));
memcpy(&in6.s6_addr[12], bytes, sizeof(uint32_t)); memcpy(&in6.s6_addr[12], bytes, sizeof(uint32_t));
if ( order == Host ) if ( order == Host ) {
{
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; uint32_t* p = (uint32_t*)&in6.s6_addr[12];
*p = htonl(*p); *p = htonl(*p);
} }
} }
else else {
{
memcpy(in6.s6_addr, bytes, sizeof(in6.s6_addr)); memcpy(in6.s6_addr, bytes, sizeof(in6.s6_addr));
if ( order == Host ) if ( order == Host ) {
{ for ( unsigned int i = 0; i < 4; ++i ) {
for ( unsigned int i = 0; i < 4; ++i )
{
uint32_t* p = (uint32_t*)&in6.s6_addr[i * 4]; uint32_t* p = (uint32_t*)&in6.s6_addr[i * 4];
*p = htonl(*p); *p = htonl(*p);
} }
} }
} }
} }
inline bool IPAddr::IsLoopback() const inline bool IPAddr::IsLoopback() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return in6.s6_addr[12] == 127; return in6.s6_addr[12] == 127;
else else
return ((in6.s6_addr[0] == 0) && (in6.s6_addr[1] == 0) && (in6.s6_addr[2] == 0) && return ((in6.s6_addr[0] == 0) && (in6.s6_addr[1] == 0) && (in6.s6_addr[2] == 0) && (in6.s6_addr[3] == 0) &&
(in6.s6_addr[3] == 0) && (in6.s6_addr[4] == 0) && (in6.s6_addr[5] == 0) && (in6.s6_addr[4] == 0) && (in6.s6_addr[5] == 0) && (in6.s6_addr[6] == 0) && (in6.s6_addr[7] == 0) &&
(in6.s6_addr[6] == 0) && (in6.s6_addr[7] == 0) && (in6.s6_addr[8] == 0) && (in6.s6_addr[8] == 0) && (in6.s6_addr[9] == 0) && (in6.s6_addr[10] == 0) && (in6.s6_addr[11] == 0) &&
(in6.s6_addr[9] == 0) && (in6.s6_addr[10] == 0) && (in6.s6_addr[11] == 0) && (in6.s6_addr[12] == 0) && (in6.s6_addr[13] == 0) && (in6.s6_addr[14] == 0) && (in6.s6_addr[15] == 1));
(in6.s6_addr[12] == 0) && (in6.s6_addr[13] == 0) && (in6.s6_addr[14] == 0) && }
(in6.s6_addr[15] == 1));
}
inline void IPAddr::ConvertToThreadingValue(threading::Value::addr_t* v) const inline void IPAddr::ConvertToThreadingValue(threading::Value::addr_t* v) const {
{
v->family = GetFamily(); v->family = GetFamily();
switch ( v->family ) switch ( v->family ) {
{ case IPv4: CopyIPv4(&v->in.in4); return;
case IPv4:
CopyIPv4(&v->in.in4);
return;
case IPv6: case IPv6: CopyIPv6(&v->in.in6); return;
CopyIPv6(&v->in.in6);
return;
}
} }
}
/** /**
* Class storing both IPv4 and IPv6 prefixes * Class storing both IPv4 and IPv6 prefixes
* (i.e., \c 192.168.1.1/16 and \c FD00::/8. * (i.e., \c 192.168.1.1/16 and \c FD00::/8.
*/ */
class IPPrefix class IPPrefix {
{
public: public:
/** /**
* Constructs a prefix 0/0. * Constructs a prefix 0/0.
@ -555,7 +506,7 @@ public:
/** /**
* Copy constructor. * Copy constructor.
*/ */
IPPrefix(const IPPrefix& other) : prefix(other.prefix), length(other.length) { } IPPrefix(const IPPrefix& other) : prefix(other.prefix), length(other.length) {}
/** /**
* Destructor. * Destructor.
@ -585,8 +536,7 @@ public:
* *
* @param addr The address to test. * @param addr The address to test.
*/ */
bool Contains(const IPAddr& addr) const bool Contains(const IPAddr& addr) const {
{
IPAddr p(addr); IPAddr p(addr);
p.Mask(length); p.Mask(length);
return p == prefix; return p == prefix;
@ -594,8 +544,7 @@ public:
/** /**
* Assignment operator. * Assignment operator.
*/ */
IPPrefix& operator=(const IPPrefix& other) IPPrefix& operator=(const IPPrefix& other) {
{
// No self-assignment check here because it's correct without it and // No self-assignment check here because it's correct without it and
// makes the common case faster. // makes the common case faster.
prefix = other.prefix; prefix = other.prefix;
@ -621,8 +570,7 @@ public:
* Converts the prefix into the type used internally by the * Converts the prefix into the type used internally by the
* inter-thread communication. * inter-thread communication.
*/ */
void ConvertToThreadingValue(threading::Value::subnet_t* v) const void ConvertToThreadingValue(threading::Value::subnet_t* v) const {
{
v->length = length; v->length = length;
prefix.ConvertToThreadingValue(&v->prefix); prefix.ConvertToThreadingValue(&v->prefix);
} }
@ -630,8 +578,7 @@ public:
/** /**
* Comparison operator for IP prefix. * Comparison operator for IP prefix.
*/ */
friend bool operator==(const IPPrefix& net1, const IPPrefix& net2) friend bool operator==(const IPPrefix& net1, const IPPrefix& net2) {
{
return net1.Prefix() == net2.Prefix() && net1.Length() == net2.Length(); return net1.Prefix() == net2.Prefix() && net1.Length() == net2.Length();
} }
@ -642,8 +589,7 @@ public:
* IP prefix. However, the order does not necessarily corresponding to their * IP prefix. However, the order does not necessarily corresponding to their
* numerical values. * numerical values.
*/ */
friend bool operator<(const IPPrefix& net1, const IPPrefix& net2) friend bool operator<(const IPPrefix& net1, const IPPrefix& net2) {
{
if ( net1.Prefix() < net2.Prefix() ) if ( net1.Prefix() < net2.Prefix() )
return true; return true;
@ -654,10 +600,7 @@ public:
return false; return false;
} }
friend bool operator<=(const IPPrefix& net1, const IPPrefix& net2) friend bool operator<=(const IPPrefix& net1, const IPPrefix& net2) { return net1 < net2 || net1 == net2; }
{
return net1 < net2 || net1 == net2;
}
friend bool operator>=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 < net2); } friend bool operator>=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 < net2); }
@ -679,8 +622,7 @@ public:
* *
* @return whether the string is a valid IP address prefix * @return whether the string is a valid IP address prefix
*/ */
static bool IsValid(const char* s) static bool IsValid(const char* s) {
{
IPPrefix tmp; IPPrefix tmp;
return ConvertString(s, &tmp); return ConvertString(s, &tmp);
} }
@ -688,6 +630,6 @@ public:
private: private:
IPAddr prefix; // We store it as an address with the non-prefix bits masked out via Mask(). IPAddr prefix; // We store it as an address with the non-prefix bits masked out via Mask().
uint8_t length = 0; // The bit length of the prefix relative to full IPv6 addr. uint8_t length = 0; // The bit length of the prefix relative to full IPv6 addr.
}; };
} // namespace zeek } // namespace zeek

View file

@ -4,11 +4,9 @@
#include <cstring> #include <cstring>
namespace zeek::detail namespace zeek::detail {
{
void IntSet::Expand(unsigned int i) void IntSet::Expand(unsigned int i) {
{
unsigned int newsize = i / 8 + 1; unsigned int newsize = i / 8 + 1;
unsigned char* newset = new unsigned char[newsize]; unsigned char* newset = new unsigned char[newsize];
@ -18,6 +16,6 @@ void IntSet::Expand(unsigned int i)
delete[] set; delete[] set;
size = newsize; size = newsize;
set = newset; set = newset;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -8,11 +8,9 @@
#include <cstring> #include <cstring>
namespace zeek::detail namespace zeek::detail {
{
class IntSet class IntSet {
{
public: public:
// n is a hint for the value of the largest integer. // n is a hint for the value of the largest integer.
explicit IntSet(unsigned int n = 1); explicit IntSet(unsigned int n = 1);
@ -29,44 +27,32 @@ private:
unsigned int size; unsigned int size;
unsigned char* set; unsigned char* set;
}; };
inline IntSet::IntSet(unsigned int n) inline IntSet::IntSet(unsigned int n) {
{
size = n / 8 + 1; size = n / 8 + 1;
set = new unsigned char[size]; set = new unsigned char[size];
memset(set, 0, size); memset(set, 0, size);
} }
inline IntSet::~IntSet() inline IntSet::~IntSet() { delete[] set; }
{
delete[] set;
}
inline void IntSet::Insert(unsigned int i) inline void IntSet::Insert(unsigned int i) {
{
if ( i / 8 >= size ) if ( i / 8 >= size )
Expand(i); Expand(i);
set[i / 8] |= (1 << (i % 8)); set[i / 8] |= (1 << (i % 8));
} }
inline void IntSet::Remove(unsigned int i) inline void IntSet::Remove(unsigned int i) {
{
if ( i / 8 >= size ) if ( i / 8 >= size )
Expand(i); Expand(i);
else else
set[i / 8] &= ~(1 << (i % 8)); set[i / 8] &= ~(1 << (i % 8));
} }
inline bool IntSet::Contains(unsigned int i) const inline bool IntSet::Contains(unsigned int i) const { return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false; }
{
return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false;
}
inline void IntSet::Clear() inline void IntSet::Clear() { memset(set, 0, size); }
{
memset(set, 0, size);
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -8,24 +8,19 @@
#include "Obj.h" #include "Obj.h"
namespace zeek namespace zeek {
{
/** /**
* A tag class for the #IntrusivePtr constructor which means: adopt * A tag class for the #IntrusivePtr constructor which means: adopt
* the reference from the caller. * the reference from the caller.
*/ */
struct AdoptRef struct AdoptRef {};
{
};
/** /**
* A tag class for the #IntrusivePtr constructor which means: create a * A tag class for the #IntrusivePtr constructor which means: create a
* new reference to the object. * new reference to the object.
*/ */
struct NewRef struct NewRef {};
{
};
/** /**
* This has to be forward declared and known here in order for us to be able * This has to be forward declared and known here in order for us to be able
@ -55,8 +50,8 @@ class OpaqueVal;
* should use a smart pointer whenever possible to reduce boilerplate code and * should use a smart pointer whenever possible to reduce boilerplate code and
* increase robustness of the code (in particular w.r.t. exceptions). * increase robustness of the code (in particular w.r.t. exceptions).
*/ */
template <class T> class IntrusivePtr template<class T>
{ class IntrusivePtr {
public: public:
// -- member types // -- member types
@ -74,8 +69,7 @@ public:
constexpr IntrusivePtr() noexcept = default; constexpr IntrusivePtr() noexcept = default;
constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() {
{
// nop // nop
} }
@ -87,7 +81,7 @@ public:
* *
* @param raw_ptr Pointer to the shared object. * @param raw_ptr Pointer to the shared object.
*/ */
constexpr IntrusivePtr(AdoptRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) { } constexpr IntrusivePtr(AdoptRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) {}
/** /**
* Constructs a new intrusive pointer for managing the lifetime of the object * Constructs a new intrusive pointer for managing the lifetime of the object
@ -97,29 +91,24 @@ public:
* *
* @param raw_ptr Pointer to the shared object. * @param raw_ptr Pointer to the shared object.
*/ */
IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) {
{
if ( ptr_ ) if ( ptr_ )
Ref(ptr_); Ref(ptr_);
} }
IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) {
{
// nop // nop
} }
IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) { } IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) {}
template <class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>> template<class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>>
IntrusivePtr(IntrusivePtr<U> other) noexcept : ptr_(other.release()) IntrusivePtr(IntrusivePtr<U> other) noexcept : ptr_(other.release()) {
{
// nop // nop
} }
~IntrusivePtr() ~IntrusivePtr() {
{ if ( ptr_ ) {
if ( ptr_ )
{
// Specializing `OpaqueVal` as MSVC compiler does not detect it // Specializing `OpaqueVal` as MSVC compiler does not detect it
// inheriting from `zeek::Obj` so we have to do that manually. // inheriting from `zeek::Obj` so we have to do that manually.
if constexpr ( std::is_same_v<T, OpaqueVal> ) if constexpr ( std::is_same_v<T, OpaqueVal> )
@ -131,8 +120,7 @@ public:
void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); } void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); }
friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept {
{
using std::swap; using std::swap;
swap(a.ptr_, b.ptr_); swap(a.ptr_, b.ptr_);
} }
@ -144,23 +132,19 @@ public:
*/ */
pointer release() noexcept { return std::exchange(ptr_, nullptr); } pointer release() noexcept { return std::exchange(ptr_, nullptr); }
IntrusivePtr& operator=(const IntrusivePtr& other) noexcept IntrusivePtr& operator=(const IntrusivePtr& other) noexcept {
{
IntrusivePtr tmp{other}; IntrusivePtr tmp{other};
swap(tmp); swap(tmp);
return *this; return *this;
} }
IntrusivePtr& operator=(IntrusivePtr&& other) noexcept IntrusivePtr& operator=(IntrusivePtr&& other) noexcept {
{
swap(other); swap(other);
return *this; return *this;
} }
IntrusivePtr& operator=(std::nullptr_t) noexcept IntrusivePtr& operator=(std::nullptr_t) noexcept {
{ if ( ptr_ ) {
if ( ptr_ )
{
Unref(ptr_); Unref(ptr_);
ptr_ = nullptr; ptr_ = nullptr;
} }
@ -179,7 +163,7 @@ public:
private: private:
pointer ptr_ = nullptr; pointer ptr_ = nullptr;
}; };
/** /**
* Convenience function for creating a reference counted object and wrapping it * Convenience function for creating a reference counted object and wrapping it
@ -189,11 +173,11 @@ private:
* @note This function assumes that any @c T starts with a reference count of 1. * @note This function assumes that any @c T starts with a reference count of 1.
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T, class... Ts> IntrusivePtr<T> make_intrusive(Ts&&... args) template<class T, class... Ts>
{ IntrusivePtr<T> make_intrusive(Ts&&... args) {
// Assumes that objects start with a reference count of 1! // Assumes that objects start with a reference count of 1!
return {AdoptRef{}, new T(std::forward<Ts>(args)...)}; return {AdoptRef{}, new T(std::forward<Ts>(args)...)};
} }
/** /**
* Casts an @c IntrusivePtr object to another by way of static_cast on * Casts an @c IntrusivePtr object to another by way of static_cast on
@ -201,78 +185,78 @@ template <class T, class... Ts> IntrusivePtr<T> make_intrusive(Ts&&... args)
* @param p The pointer of type @c U to cast to another type, @c T. * @param p The pointer of type @c U to cast to another type, @c T.
* @return The pointer, as cast to type @c T. * @return The pointer, as cast to type @c T.
*/ */
template <class T, class U> IntrusivePtr<T> cast_intrusive(IntrusivePtr<U> p) noexcept template<class T, class U>
{ IntrusivePtr<T> cast_intrusive(IntrusivePtr<U> p) noexcept {
return {AdoptRef{}, static_cast<T*>(p.release())}; return {AdoptRef{}, static_cast<T*>(p.release())};
} }
// -- comparison to nullptr ---------------------------------------------------- // -- comparison to nullptr ----------------------------------------------------
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const zeek::IntrusivePtr<T>& x, std::nullptr_t) template<class T>
{ bool operator==(const zeek::IntrusivePtr<T>& x, std::nullptr_t) {
return ! x; return ! x;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(std::nullptr_t, const zeek::IntrusivePtr<T>& x) template<class T>
{ bool operator==(std::nullptr_t, const zeek::IntrusivePtr<T>& x) {
return ! x; return ! x;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const zeek::IntrusivePtr<T>& x, std::nullptr_t) template<class T>
{ bool operator!=(const zeek::IntrusivePtr<T>& x, std::nullptr_t) {
return static_cast<bool>(x); return static_cast<bool>(x);
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(std::nullptr_t, const zeek::IntrusivePtr<T>& x) template<class T>
{ bool operator!=(std::nullptr_t, const zeek::IntrusivePtr<T>& x) {
return static_cast<bool>(x); return static_cast<bool>(x);
} }
// -- comparison to raw pointer ------------------------------------------------ // -- comparison to raw pointer ------------------------------------------------
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const zeek::IntrusivePtr<T>& x, const T* y) template<class T>
{ bool operator==(const zeek::IntrusivePtr<T>& x, const T* y) {
return x.get() == y; return x.get() == y;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const T* x, const zeek::IntrusivePtr<T>& y) template<class T>
{ bool operator==(const T* x, const zeek::IntrusivePtr<T>& y) {
return x == y.get(); return x == y.get();
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const zeek::IntrusivePtr<T>& x, const T* y) template<class T>
{ bool operator!=(const zeek::IntrusivePtr<T>& x, const T* y) {
return x.get() != y; return x.get() != y;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y) template<class T>
{ bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y) {
return x != y.get(); return x != y.get();
} }
// -- comparison to intrusive pointer ------------------------------------------ // -- comparison to intrusive pointer ------------------------------------------
@ -282,35 +266,27 @@ template <class T> bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y)
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T, class U> template<class T, class U>
auto operator==(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) auto operator==(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) -> decltype(x.get() == y.get()) {
-> decltype(x.get() == y.get())
{
return x.get() == y.get(); return x.get() == y.get();
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T, class U> template<class T, class U>
auto operator!=(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) auto operator!=(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) -> decltype(x.get() != y.get()) {
-> decltype(x.get() != y.get())
{
return x.get() != y.get(); return x.get() != y.get();
} }
} // namespace zeek } // namespace zeek
// -- hashing ------------------------------------------------ // -- hashing ------------------------------------------------
namespace std namespace std {
{ template<class T>
template <class T> struct hash<zeek::IntrusivePtr<T>> struct hash<zeek::IntrusivePtr<T>> {
{
// Hash of intrusive pointer is the same as hash of the raw pointer it holds. // Hash of intrusive pointer is the same as hash of the raw pointer it holds.
size_t operator()(const zeek::IntrusivePtr<T>& v) const noexcept size_t operator()(const zeek::IntrusivePtr<T>& v) const noexcept { return std::hash<T*>{}(v.get()); }
{ };
return std::hash<T*>{}(v.get()); } // namespace std
}
};
}

View file

@ -2,18 +2,16 @@
#include "zeek/3rdparty/doctest.h" #include "zeek/3rdparty/doctest.h"
TEST_CASE("list construction") TEST_CASE("list construction") {
{
zeek::List<int> list; zeek::List<int> list;
CHECK(list.empty()); CHECK(list.empty());
zeek::List<int> list2(10); zeek::List<int> list2(10);
CHECK(list2.empty()); CHECK(list2.empty());
CHECK(list2.max() == 10); CHECK(list2.max() == 10);
} }
TEST_CASE("list operation") TEST_CASE("list operation") {
{
zeek::List<int> list({1, 2, 3}); zeek::List<int> list({1, 2, 3});
CHECK(list.size() == 3); CHECK(list.size() == 3);
CHECK(list.max() == 3); CHECK(list.max() == 3);
@ -82,10 +80,9 @@ TEST_CASE("list operation")
list.clear(); list.clear();
CHECK(list.size() == 0); CHECK(list.size() == 0);
CHECK(list.max() == 0); CHECK(list.max() == 0);
} }
TEST_CASE("list iteration") TEST_CASE("list iteration") {
{
zeek::List<int> list({1, 2, 3, 4}); zeek::List<int> list({1, 2, 3, 4});
int index = 1; int index = 1;
@ -95,10 +92,9 @@ TEST_CASE("list iteration")
index = 1; index = 1;
for ( auto it = list.begin(); it != list.end(); index++, ++it ) for ( auto it = list.begin(); it != list.end(); index++, ++it )
CHECK(*it == index); CHECK(*it == index);
} }
TEST_CASE("plists") TEST_CASE("plists") {
{
zeek::PList<int> list; zeek::PList<int> list;
list.push_back(new int{1}); list.push_back(new int{1});
list.push_back(new int{2}); list.push_back(new int{2});
@ -114,10 +110,9 @@ TEST_CASE("plists")
for ( auto v : list ) for ( auto v : list )
delete v; delete v;
list.clear(); list.clear();
} }
TEST_CASE("unordered list operation") TEST_CASE("unordered list operation") {
{
zeek::List<int, zeek::ListOrder::UNORDERED> list({1, 2, 3, 4}); zeek::List<int, zeek::ListOrder::UNORDERED> list({1, 2, 3, 4});
CHECK(list.size() == 4); CHECK(list.size() == 4);
@ -128,4 +123,4 @@ TEST_CASE("unordered list operation")
CHECK(list[0] == 1); CHECK(list[0] == 1);
CHECK(list[1] == 4); CHECK(list[1] == 4);
CHECK(list[2] == 3); CHECK(list[2] == 3);
} }

View file

@ -27,28 +27,21 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
enum class ListOrder : int enum class ListOrder : int { ORDERED, UNORDERED };
{
ORDERED,
UNORDERED
};
template <typename T, ListOrder Order = ListOrder::ORDERED> class List template<typename T, ListOrder Order = ListOrder::ORDERED>
{ class List {
public: public:
constexpr static int DEFAULT_LIST_SIZE = 10; constexpr static int DEFAULT_LIST_SIZE = 10;
constexpr static int LIST_GROWTH_FACTOR = 2; constexpr static int LIST_GROWTH_FACTOR = 2;
~List() { free(entries); } ~List() { free(entries); }
explicit List(int size = 0) explicit List(int size = 0) {
{
num_entries = 0; num_entries = 0;
if ( size <= 0 ) if ( size <= 0 ) {
{
max_entries = 0; max_entries = 0;
entries = nullptr; entries = nullptr;
return; return;
@ -59,8 +52,7 @@ public:
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); entries = (T*)util::safe_malloc(max_entries * sizeof(T));
} }
List(const List& b) List(const List& b) {
{
max_entries = b.max_entries; max_entries = b.max_entries;
num_entries = b.num_entries; num_entries = b.num_entries;
@ -73,8 +65,7 @@ public:
entries[i] = b.entries[i]; entries[i] = b.entries[i];
} }
List(List&& b) List(List&& b) {
{
entries = b.entries; entries = b.entries;
num_entries = b.num_entries; num_entries = b.num_entries;
max_entries = b.max_entries; max_entries = b.max_entries;
@ -83,17 +74,15 @@ public:
b.num_entries = b.max_entries = 0; b.num_entries = b.max_entries = 0;
} }
List(const T* arr, int n) List(const T* arr, int n) {
{
num_entries = max_entries = n; num_entries = max_entries = n;
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); entries = (T*)util::safe_malloc(max_entries * sizeof(T));
memcpy(entries, arr, n * sizeof(T)); memcpy(entries, arr, n * sizeof(T));
} }
List(std::initializer_list<T> il) : List(il.begin(), il.size()) { } List(std::initializer_list<T> il) : List(il.begin(), il.size()) {}
List& operator=(const List& b) List& operator=(const List& b) {
{
if ( this == &b ) if ( this == &b )
return *this; return *this;
@ -113,8 +102,7 @@ public:
return *this; return *this;
} }
List& operator=(List&& b) List& operator=(List&& b) {
{
if ( this == &b ) if ( this == &b )
return *this; return *this;
@ -148,8 +136,7 @@ public:
if ( new_size < num_entries ) if ( new_size < num_entries )
new_size = num_entries; // do not lose any entries new_size = num_entries; // do not lose any entries
if ( new_size != max_entries ) if ( new_size != max_entries ) {
{
entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size); entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size);
if ( entries ) if ( entries )
max_entries = new_size; max_entries = new_size;
@ -160,8 +147,7 @@ public:
return max_entries; return max_entries;
} }
void push_front(const T& a) void push_front(const T& a) {
{
if ( num_entries == max_entries ) if ( num_entries == max_entries )
resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
@ -172,8 +158,7 @@ public:
entries[0] = a; entries[0] = a;
} }
void push_back(const T& a) void push_back(const T& a) {
{
if ( num_entries == max_entries ) if ( num_entries == max_entries )
resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
@ -196,8 +181,7 @@ public:
bool remove(const T& a) // delete entry from list bool remove(const T& a) // delete entry from list
{ {
int pos = member_pos(a); int pos = member_pos(a);
if ( pos != -1 ) if ( pos != -1 ) {
{
remove_nth(pos); remove_nth(pos);
return true; return true;
} }
@ -214,15 +198,13 @@ public:
// For data where we don't care about ordering, we don't care about keeping // For data where we don't care about ordering, we don't care about keeping
// the list in the same order when removing an element. Just swap the last // the list in the same order when removing an element. Just swap the last
// element with the element being removed. // element with the element being removed.
if constexpr ( Order == ListOrder::ORDERED ) if constexpr ( Order == ListOrder::ORDERED ) {
{
--num_entries; --num_entries;
for ( ; n < num_entries; ++n ) for ( ; n < num_entries; ++n )
entries[n] = entries[n + 1]; entries[n] = entries[n + 1];
} }
else else {
{
entries[n] = entries[num_entries - 1]; entries[n] = entries[num_entries - 1];
--num_entries; --num_entries;
} }
@ -231,15 +213,13 @@ public:
} }
// Return 0 if ent is not in the list, ent otherwise. // Return 0 if ent is not in the list, ent otherwise.
bool is_member(const T& a) const bool is_member(const T& a) const {
{
int pos = member_pos(a); int pos = member_pos(a);
return pos != -1; return pos != -1;
} }
// Returns -1 if ent is not in the list, otherwise its position. // Returns -1 if ent is not in the list, otherwise its position.
int member_pos(const T& e) const int member_pos(const T& e) const {
{
int i; int i;
for ( i = 0; i < length() && e != entries[i]; ++i ) for ( i = 0; i < length() && e != entries[i]; ++i )
; ;
@ -254,8 +234,7 @@ public:
T old_ent{}; T old_ent{};
if ( ent_index > num_entries - 1 ) if ( ent_index > num_entries - 1 ) { // replacement beyond the end of the list
{ // replacement beyond the end of the list
resize(ent_index + 1); resize(ent_index + 1);
for ( int i = num_entries; i < max_entries; ++i ) for ( int i = num_entries; i < max_entries; ++i )
@ -323,15 +302,16 @@ protected:
T* entries; T* entries;
int max_entries; int max_entries;
int num_entries; int num_entries;
}; };
// Specialization of the List class to store pointers of a type. // Specialization of the List class to store pointers of a type.
template <typename T, ListOrder Order = ListOrder::ORDERED> using PList = List<T*, Order>; template<typename T, ListOrder Order = ListOrder::ORDERED>
using PList = List<T*, Order>;
// Popular type of list: list of strings. // Popular type of list: list of strings.
using name_list = PList<char>; using name_list = PList<char>;
} // namespace zeek } // namespace zeek
// Macro to visit each list element in turn. // Macro to visit each list element in turn.
#define loop_over_list(list, iterator) \ #define loop_over_list(list, iterator) \

View file

@ -10,13 +10,11 @@
#include "zeek/EquivClass.h" #include "zeek/EquivClass.h"
#include "zeek/IntSet.h" #include "zeek/IntSet.h"
namespace zeek::detail namespace zeek::detail {
{
static int nfa_state_id = 0; static int nfa_state_id = 0;
NFA_State::NFA_State(int arg_sym, EquivClass* ec) NFA_State::NFA_State(int arg_sym, EquivClass* ec) {
{
sym = arg_sym; sym = arg_sym;
ccl = nullptr; ccl = nullptr;
accept = NO_ACCEPT; accept = NO_ACCEPT;
@ -35,10 +33,9 @@ NFA_State::NFA_State(int arg_sym, EquivClass* ec)
if ( ec && sym != SYM_EPSILON /* no associated symbol */ ) if ( ec && sym != SYM_EPSILON /* no associated symbol */ )
ec->UniqueChar(sym); ec->UniqueChar(sym);
} }
NFA_State::NFA_State(CCL* arg_ccl) NFA_State::NFA_State(CCL* arg_ccl) {
{
sym = SYM_CCL; sym = SYM_CCL;
ccl = arg_ccl; ccl = arg_ccl;
accept = NO_ACCEPT; accept = NO_ACCEPT;
@ -46,27 +43,23 @@ NFA_State::NFA_State(CCL* arg_ccl)
mark = nullptr; mark = nullptr;
id = ++nfa_state_id; id = ++nfa_state_id;
epsclosure = nullptr; epsclosure = nullptr;
} }
NFA_State::~NFA_State() NFA_State::~NFA_State() {
{
for ( int i = 0; i < xtions.length(); ++i ) for ( int i = 0; i < xtions.length(); ++i )
if ( i > 0 || ! first_trans_is_back_ref ) if ( i > 0 || ! first_trans_is_back_ref )
Unref(xtions[i]); Unref(xtions[i]);
delete epsclosure; delete epsclosure;
} }
void NFA_State::AddXtionsTo(NFA_state_list* ns) void NFA_State::AddXtionsTo(NFA_state_list* ns) {
{
for ( int i = 0; i < xtions.length(); ++i ) for ( int i = 0; i < xtions.length(); ++i )
ns->push_back(xtions[i]); ns->push_back(xtions[i]);
} }
NFA_State* NFA_State::DeepCopy() NFA_State* NFA_State::DeepCopy() {
{ if ( mark ) {
if ( mark )
{
Ref(mark); Ref(mark);
return mark; return mark;
} }
@ -78,20 +71,17 @@ NFA_State* NFA_State::DeepCopy()
copy->AddXtion(xtions[i]->DeepCopy()); copy->AddXtion(xtions[i]->DeepCopy());
return copy; return copy;
} }
void NFA_State::ClearMarks() void NFA_State::ClearMarks() {
{ if ( mark ) {
if ( mark )
{
SetMark(nullptr); SetMark(nullptr);
for ( int i = 0; i < xtions.length(); ++i ) for ( int i = 0; i < xtions.length(); ++i )
xtions[i]->ClearMarks(); xtions[i]->ClearMarks();
} }
} }
NFA_state_list* NFA_State::EpsilonClosure() NFA_state_list* NFA_State::EpsilonClosure() {
{
if ( epsclosure ) if ( epsclosure )
return epsclosure; return epsclosure;
@ -102,17 +92,13 @@ NFA_state_list* NFA_State::EpsilonClosure()
SetMark(this); SetMark(this);
int i; int i;
for ( i = 0; i < states.length(); ++i ) for ( i = 0; i < states.length(); ++i ) {
{
NFA_State* ns = states[i]; NFA_State* ns = states[i];
if ( ns->TransSym() == SYM_EPSILON ) if ( ns->TransSym() == SYM_EPSILON ) {
{
NFA_state_list* x = ns->Transitions(); NFA_state_list* x = ns->Transitions();
for ( int j = 0; j < x->length(); ++j ) for ( int j = 0; j < x->length(); ++j ) {
{
NFA_State* nxt = (*x)[j]; NFA_State* nxt = (*x)[j];
if ( ! nxt->Mark() ) if ( ! nxt->Mark() ) {
{
states.push_back(nxt); states.push_back(nxt);
nxt->SetMark(nxt); nxt->SetMark(nxt);
} }
@ -135,15 +121,11 @@ NFA_state_list* NFA_State::EpsilonClosure()
epsclosure->resize(0); epsclosure->resize(0);
return epsclosure; return epsclosure;
} }
void NFA_State::Describe(ODesc* d) const void NFA_State::Describe(ODesc* d) const { d->Add("NFA state"); }
{
d->Add("NFA state");
}
void NFA_State::Dump(FILE* f) void NFA_State::Dump(FILE* f) {
{
if ( mark ) if ( mark )
return; return;
@ -155,34 +137,25 @@ void NFA_State::Dump(FILE* f)
SetMark(this); SetMark(this);
for ( int i = 0; i < xtions.length(); ++i ) for ( int i = 0; i < xtions.length(); ++i )
xtions[i]->Dump(f); xtions[i]->Dump(f);
} }
NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) {
{
first_state = first; first_state = first;
final_state = final ? final : first; final_state = final ? final : first;
eol = bol = 0; eol = bol = 0;
} }
NFA_Machine::~NFA_Machine() NFA_Machine::~NFA_Machine() { Unref(first_state); }
{
Unref(first_state);
}
void NFA_Machine::InsertEpsilon() void NFA_Machine::InsertEpsilon() {
{
NFA_State* eps = new EpsilonState(); NFA_State* eps = new EpsilonState();
eps->AddXtion(first_state); eps->AddXtion(first_state);
first_state = eps; first_state = eps;
} }
void NFA_Machine::AppendEpsilon() void NFA_Machine::AppendEpsilon() { AppendState(new EpsilonState()); }
{
AppendState(new EpsilonState());
}
void NFA_Machine::AddAccept(int accept_val) void NFA_Machine::AddAccept(int accept_val) {
{
// Hang the accepting number off an epsilon state. If it is associated // Hang the accepting number off an epsilon state. If it is associated
// with a state that has a non-epsilon out-transition, then the state // with a state that has a non-epsilon out-transition, then the state
// will accept BEFORE it makes that transition, i.e., one character // will accept BEFORE it makes that transition, i.e., one character
@ -192,10 +165,9 @@ void NFA_Machine::AddAccept(int accept_val)
AppendState(new EpsilonState()); AppendState(new EpsilonState());
final_state->SetAccept(accept_val); final_state->SetAccept(accept_val);
} }
void NFA_Machine::LinkCopies(int n) void NFA_Machine::LinkCopies(int n) {
{
if ( n <= 0 ) if ( n <= 0 )
return; return;
@ -212,68 +184,60 @@ void NFA_Machine::LinkCopies(int n)
AppendMachine(copies[i]); AppendMachine(copies[i]);
delete[] copies; delete[] copies;
} }
NFA_Machine* NFA_Machine::DuplicateMachine() NFA_Machine* NFA_Machine::DuplicateMachine() {
{
NFA_State* new_first_state = first_state->DeepCopy(); NFA_State* new_first_state = first_state->DeepCopy();
NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark()); NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark());
first_state->ClearMarks(); first_state->ClearMarks();
return new_m; return new_m;
} }
void NFA_Machine::AppendState(NFA_State* s) void NFA_Machine::AppendState(NFA_State* s) {
{
final_state->AddXtion(s); final_state->AddXtion(s);
final_state = s; final_state = s;
} }
void NFA_Machine::AppendMachine(NFA_Machine* m) void NFA_Machine::AppendMachine(NFA_Machine* m) {
{
AppendEpsilon(); AppendEpsilon();
final_state->AddXtion(m->FirstState()); final_state->AddXtion(m->FirstState());
final_state = m->FinalState(); final_state = m->FinalState();
Ref(m->FirstState()); // so states stay around after the following Ref(m->FirstState()); // so states stay around after the following
Unref(m); Unref(m);
} }
void NFA_Machine::MakeOptional() void NFA_Machine::MakeOptional() {
{
InsertEpsilon(); InsertEpsilon();
AppendEpsilon(); AppendEpsilon();
first_state->AddXtion(final_state); first_state->AddXtion(final_state);
Ref(final_state); Ref(final_state);
} }
void NFA_Machine::MakePositiveClosure() void NFA_Machine::MakePositiveClosure() {
{
AppendEpsilon(); AppendEpsilon();
final_state->AddXtion(first_state); final_state->AddXtion(first_state);
// Don't Ref the state the final epsilon points to, otherwise we'll // Don't Ref the state the final epsilon points to, otherwise we'll
// have reference cycles that lead to leaks. // have reference cycles that lead to leaks.
final_state->SetFirstTransIsBackRef(); final_state->SetFirstTransIsBackRef();
} }
void NFA_Machine::MakeRepl(int lower, int upper) void NFA_Machine::MakeRepl(int lower, int upper) {
{
NFA_Machine* dup = nullptr; NFA_Machine* dup = nullptr;
if ( upper > lower || upper == NO_UPPER_BOUND ) if ( upper > lower || upper == NO_UPPER_BOUND )
dup = DuplicateMachine(); dup = DuplicateMachine();
LinkCopies(lower - 1); LinkCopies(lower - 1);
if ( upper == NO_UPPER_BOUND ) if ( upper == NO_UPPER_BOUND ) {
{
dup->MakeClosure(); dup->MakeClosure();
AppendMachine(dup); AppendMachine(dup);
return; return;
} }
while ( upper > lower ) while ( upper > lower ) {
{
NFA_Machine* dup2; NFA_Machine* dup2;
if ( --upper == lower ) if ( --upper == lower )
// Don't need "dup" for any further copies // Don't need "dup" for any further copies
@ -284,21 +248,16 @@ void NFA_Machine::MakeRepl(int lower, int upper)
dup2->MakeOptional(); dup2->MakeOptional();
AppendMachine(dup2); AppendMachine(dup2);
} }
} }
void NFA_Machine::Describe(ODesc* d) const void NFA_Machine::Describe(ODesc* d) const { d->Add("NFA machine"); }
{
d->Add("NFA machine");
}
void NFA_Machine::Dump(FILE* f) void NFA_Machine::Dump(FILE* f) {
{
first_state->Dump(f); first_state->Dump(f);
first_state->ClearMarks(); first_state->ClearMarks();
} }
NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) {
{
if ( ! m1 ) if ( ! m1 )
return m2; return m2;
if ( ! m2 ) if ( ! m2 )
@ -322,25 +281,21 @@ NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2)
Unref(m2); Unref(m2);
return new NFA_Machine(first, last); return new NFA_Machine(first, last);
} }
NFA_state_list* epsilon_closure(NFA_state_list* states) NFA_state_list* epsilon_closure(NFA_state_list* states) {
{
// We just keep one of this as it may get quite large. // We just keep one of this as it may get quite large.
static IntSet closuremap; static IntSet closuremap;
closuremap.Clear(); closuremap.Clear();
NFA_state_list* closure = new NFA_state_list; NFA_state_list* closure = new NFA_state_list;
for ( int i = 0; i < states->length(); ++i ) for ( int i = 0; i < states->length(); ++i ) {
{
NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure(); NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure();
for ( int j = 0; j < stateclosure->length(); ++j ) for ( int j = 0; j < stateclosure->length(); ++j ) {
{
NFA_State* ns = (*stateclosure)[j]; NFA_State* ns = (*stateclosure)[j];
if ( ! closuremap.Contains(ns->ID()) ) if ( ! closuremap.Contains(ns->ID()) ) {
{
closuremap.Insert(ns->ID()); closuremap.Insert(ns->ID());
closure->push_back(ns); closure->push_back(ns);
} }
@ -356,14 +311,13 @@ NFA_state_list* epsilon_closure(NFA_state_list* states)
delete states; delete states;
return closure; return closure;
} }
bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2) bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2) {
{
if ( v1->ID() < v2->ID() ) if ( v1->ID() < v2->ID() )
return true; return true;
else else
return false; return false;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -16,13 +16,11 @@
#define SYM_EPSILON 259 #define SYM_EPSILON 259
#define SYM_CCL 260 #define SYM_CCL 260
namespace zeek namespace zeek {
{
class Func; class Func;
namespace detail namespace detail {
{
class CCL; class CCL;
class EquivClass; class EquivClass;
@ -30,8 +28,7 @@ class EquivClass;
class NFA_State; class NFA_State;
using NFA_state_list = PList<NFA_State>; using NFA_state_list = PList<NFA_State>;
class NFA_State : public Obj class NFA_State : public Obj {
{
public: public:
NFA_State(int sym, EquivClass* ec); NFA_State(int sym, EquivClass* ec);
explicit NFA_State(CCL* ccl); explicit NFA_State(CCL* ccl);
@ -77,16 +74,14 @@ protected:
NFA_state_list xtions; NFA_state_list xtions;
NFA_state_list* epsclosure; NFA_state_list* epsclosure;
NFA_State* mark; NFA_State* mark;
}; };
class EpsilonState : public NFA_State class EpsilonState : public NFA_State {
{
public: public:
EpsilonState() : NFA_State(SYM_EPSILON, nullptr) { } EpsilonState() : NFA_State(SYM_EPSILON, nullptr) {}
}; };
class NFA_Machine : public Obj class NFA_Machine : public Obj {
{
public: public:
explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr); explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr);
~NFA_Machine() override; ~NFA_Machine() override;
@ -98,8 +93,7 @@ public:
void AddAccept(int accept_val); void AddAccept(int accept_val);
void MakeClosure() void MakeClosure() {
{
MakePositiveClosure(); MakePositiveClosure();
MakeOptional(); MakeOptional();
} }
@ -127,7 +121,7 @@ protected:
NFA_State* first_state; NFA_State* first_state;
NFA_State* final_state; NFA_State* final_state;
int bol, eol; int bol, eol;
}; };
extern NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2); extern NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2);
@ -141,5 +135,5 @@ extern NFA_state_list* epsilon_closure(NFA_state_list* states);
// For sorting NFA states based on their ID fields (decreasing) // For sorting NFA states based on their ID fields (decreasing)
extern bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2); extern bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2);
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -105,8 +105,7 @@ zeek::StringVal* cmd_line_bpf_filter;
zeek::StringVal* global_hash_seed; zeek::StringVal* global_hash_seed;
namespace zeek::detail namespace zeek::detail {
{
int watchdog_interval; int watchdog_interval;
@ -194,29 +193,26 @@ zeek_uint_t bits_per_uid;
zeek_uint_t tunnel_max_changes_per_connection; zeek_uint_t tunnel_max_changes_per_connection;
} // namespace zeek::detail. The namespace has be closed here before we include the netvar_def } // namespace zeek::detail
// files. // files.
// Because of how the BIF include files are built with namespaces already in them, // Because of how the BIF include files are built with namespaces already in them,
// these files need to be included separately before the namespace is opened below. // these files need to be included separately before the namespace is opened below.
static void bif_init_event_handlers() static void bif_init_event_handlers() {
{
#include "event.bif.netvar_init" #include "event.bif.netvar_init"
} }
static void bif_init_net_var() static void bif_init_net_var() {
{
#include "const.bif.netvar_init" #include "const.bif.netvar_init"
#include "packet_analysis.bif.netvar_init" #include "packet_analysis.bif.netvar_init"
#include "reporter.bif.netvar_init" #include "reporter.bif.netvar_init"
#include "supervisor.bif.netvar_init" #include "supervisor.bif.netvar_init"
} }
static void init_bif_types() static void init_bif_types() {
{
#include "types.bif.netvar_init" #include "types.bif.netvar_init"
} }
#include "const.bif.netvar_def" #include "const.bif.netvar_def"
#include "event.bif.netvar_def" #include "event.bif.netvar_def"
@ -226,16 +222,11 @@ static void init_bif_types()
#include "types.bif.netvar_def" #include "types.bif.netvar_def"
// Re-open the namespace now that the bif headers are all included. // Re-open the namespace now that the bif headers are all included.
namespace zeek::detail namespace zeek::detail {
{
void init_event_handlers() void init_event_handlers() { bif_init_event_handlers(); }
{
bif_init_event_handlers();
}
void init_general_global_var() void init_general_global_var() {
{
table_expire_interval = id::find_val("table_expire_interval")->AsInterval(); table_expire_interval = id::find_val("table_expire_interval")->AsInterval();
table_expire_delay = id::find_val("table_expire_delay")->AsInterval(); table_expire_delay = id::find_val("table_expire_delay")->AsInterval();
table_incremental_step = id::find_val("table_incremental_step")->AsCount(); table_incremental_step = id::find_val("table_incremental_step")->AsCount();
@ -244,16 +235,14 @@ void init_general_global_var()
check_for_unused_event_handlers = id::find_val("check_for_unused_event_handlers")->AsBool(); check_for_unused_event_handlers = id::find_val("check_for_unused_event_handlers")->AsBool();
record_all_packets = id::find_val("record_all_packets")->AsBool(); record_all_packets = id::find_val("record_all_packets")->AsBool();
bits_per_uid = id::find_val("bits_per_uid")->AsCount(); bits_per_uid = id::find_val("bits_per_uid")->AsCount();
} }
void init_builtin_types() void init_builtin_types() {
{
init_bif_types(); init_bif_types();
id::detail::init_types(); id::detail::init_types();
} }
void init_net_var() void init_net_var() {
{
bif_init_net_var(); bif_init_net_var();
ignore_checksums = id::find_val("ignore_checksums")->AsBool(); ignore_checksums = id::find_val("ignore_checksums")->AsBool();
@ -272,10 +261,8 @@ void init_net_var()
tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval(); tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval();
tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount(); tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount();
tcp_max_above_hole_without_any_acks = tcp_max_above_hole_without_any_acks = id::find_val("tcp_max_above_hole_without_any_acks")->AsCount();
id::find_val("tcp_max_above_hole_without_any_acks")->AsCount(); tcp_excessive_data_without_further_acks = id::find_val("tcp_excessive_data_without_further_acks")->AsCount();
tcp_excessive_data_without_further_acks =
id::find_val("tcp_excessive_data_without_further_acks")->AsCount();
tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount(); tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount();
non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval(); non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval();
@ -291,8 +278,7 @@ void init_net_var()
udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool()); udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool());
udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool()); udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool());
udp_content_delivery_ports_use_resp = bool( udp_content_delivery_ports_use_resp = bool(id::find_val("udp_content_delivery_ports_use_resp")->AsBool());
id::find_val("udp_content_delivery_ports_use_resp")->AsBool());
dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval(); dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval();
rpc_timeout = id::find_val("rpc_timeout")->AsInterval(); rpc_timeout = id::find_val("rpc_timeout")->AsInterval();
@ -345,8 +331,7 @@ void init_net_var()
dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool(); dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool();
dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool(); dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool();
tunnel_max_changes_per_connection = tunnel_max_changes_per_connection = id::find_val("Tunnel::max_changes_per_connection")->AsCount();
id::find_val("Tunnel::max_changes_per_connection")->AsCount(); }
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,8 +6,7 @@
#include "zeek/Stats.h" #include "zeek/Stats.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
extern int watchdog_interval; extern int watchdog_interval;
@ -103,7 +102,7 @@ extern void init_event_handlers();
extern void init_net_var(); extern void init_net_var();
extern void init_builtin_types(); extern void init_builtin_types();
} // namespace zeek::detail } // namespace zeek::detail
#include "const.bif.netvar_h" #include "const.bif.netvar_h"
#include "event.bif.netvar_h" #include "event.bif.netvar_h"

View file

@ -8,51 +8,38 @@
zeek::notifier::detail::Registry zeek::notifier::detail::registry; zeek::notifier::detail::Registry zeek::notifier::detail::registry;
namespace zeek::notifier::detail namespace zeek::notifier::detail {
{
Receiver::Receiver() Receiver::Receiver() { DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this); }
{
DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this);
}
Receiver::~Receiver() Receiver::~Receiver() { DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this); }
{
DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this);
}
Registry::~Registry() Registry::~Registry() {
{
while ( registrations.begin() != registrations.end() ) while ( registrations.begin() != registrations.end() )
Unregister(registrations.begin()->first); Unregister(registrations.begin()->first);
} }
void Registry::Register(Modifiable* m, Receiver* r) void Registry::Register(Modifiable* m, Receiver* r) {
{
DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r); DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r);
registrations.insert({m, r}); registrations.insert({m, r});
++m->num_receivers; ++m->num_receivers;
} }
void Registry::Unregister(Modifiable* m, Receiver* r) void Registry::Unregister(Modifiable* m, Receiver* r) {
{
DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r); DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
for ( auto i = x.first; i != x.second; i++ ) for ( auto i = x.first; i != x.second; i++ ) {
{ if ( i->second == r ) {
if ( i->second == r )
{
--i->first->num_receivers; --i->first->num_receivers;
registrations.erase(i); registrations.erase(i);
break; break;
} }
} }
} }
void Registry::Unregister(Modifiable* m) void Registry::Unregister(Modifiable* m) {
{
DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m); DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
@ -60,19 +47,17 @@ void Registry::Unregister(Modifiable* m)
--i->first->num_receivers; --i->first->num_receivers;
registrations.erase(x.first, x.second); registrations.erase(x.first, x.second);
} }
void Registry::Modified(Modifiable* m) void Registry::Modified(Modifiable* m) {
{
DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m); DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
for ( auto i = x.first; i != x.second; i++ ) for ( auto i = x.first; i != x.second; i++ )
i->second->Modified(m); i->second->Modified(m);
} }
void Registry::Terminate() void Registry::Terminate() {
{
std::set<Receiver*> receivers; std::set<Receiver*> receivers;
for ( auto& r : registrations ) for ( auto& r : registrations )
@ -80,12 +65,11 @@ void Registry::Terminate()
for ( auto& r : receivers ) for ( auto& r : receivers )
r->Terminate(); r->Terminate();
} }
Modifiable::~Modifiable() Modifiable::~Modifiable() {
{
if ( num_receivers ) if ( num_receivers )
registry.Unregister(this); registry.Unregister(this);
} }
} // namespace zeek::notifier::detail } // namespace zeek::notifier::detail

View file

@ -10,14 +10,12 @@
#include <cstdint> #include <cstdint>
#include <unordered_map> #include <unordered_map>
namespace zeek::notifier::detail namespace zeek::notifier::detail {
{
class Modifiable; class Modifiable;
/** Interface class for receivers of notifications. */ /** Interface class for receivers of notifications. */
class Receiver class Receiver {
{
public: public:
Receiver(); Receiver();
virtual ~Receiver(); virtual ~Receiver();
@ -33,12 +31,11 @@ public:
* Callback executed when notification registry is terminating and * Callback executed when notification registry is terminating and
* no further modifications can possibly occur. * no further modifications can possibly occur.
*/ */
virtual void Terminate() { } virtual void Terminate() {}
}; };
/** Singleton class tracking all notification requests globally. */ /** Singleton class tracking all notification requests globally. */
class Registry class Registry {
{
public: public:
~Registry(); ~Registry();
@ -89,7 +86,7 @@ private:
using ModifiableMap = std::unordered_multimap<Modifiable*, Receiver*>; using ModifiableMap = std::unordered_multimap<Modifiable*, Receiver*>;
ModifiableMap registrations; ModifiableMap registrations;
}; };
/** /**
* Singleton object tracking all global notification requests. * Singleton object tracking all global notification requests.
@ -100,15 +97,13 @@ extern Registry registry;
* Base class for objects that can trigger notifications to receivers when * Base class for objects that can trigger notifications to receivers when
* modified. * modified.
*/ */
class Modifiable class Modifiable {
{
public: public:
/** /**
* Calling this method signals to all registered receivers that the * Calling this method signals to all registered receivers that the
* object has been modified. * object has been modified.
*/ */
void Modified() void Modified() {
{
if ( num_receivers ) if ( num_receivers )
registry.Modified(this); registry.Modified(this);
} }
@ -120,6 +115,6 @@ protected:
// Number of currently registered receivers. // Number of currently registered receivers.
uint64_t num_receivers = 0; uint64_t num_receivers = 0;
}; };
} // namespace zeek::notifier::detail } // namespace zeek::notifier::detail

View file

@ -11,18 +11,14 @@
#include "zeek/Func.h" #include "zeek/Func.h"
#include "zeek/plugin/Manager.h" #include "zeek/plugin/Manager.h"
namespace zeek namespace zeek {
{ namespace detail {
namespace detail
{
Location start_location("<start uninitialized>", 0, 0, 0, 0); Location start_location("<start uninitialized>", 0, 0, 0, 0);
Location end_location("<end uninitialized>", 0, 0, 0, 0); Location end_location("<end uninitialized>", 0, 0, 0, 0);
void Location::Describe(ODesc* d) const void Location::Describe(ODesc* d) const {
{ if ( filename ) {
if ( filename )
{
d->Add(filename); d->Add(filename);
if ( first_line == 0 ) if ( first_line == 0 )
@ -31,52 +27,44 @@ void Location::Describe(ODesc* d) const
d->AddSP(","); d->AddSP(",");
} }
if ( last_line != first_line ) if ( last_line != first_line ) {
{
d->Add("lines "); d->Add("lines ");
d->Add(first_line); d->Add(first_line);
d->Add("-"); d->Add("-");
d->Add(last_line); d->Add(last_line);
} }
else else {
{
d->Add("line "); d->Add("line ");
d->Add(first_line); d->Add(first_line);
} }
} }
bool Location::operator==(const Location& l) const bool Location::operator==(const Location& l) const {
{
if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) ) if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) )
return first_line == l.first_line && last_line == l.last_line; return first_line == l.first_line && last_line == l.last_line;
else else
return false; return false;
} }
} // namespace detail } // namespace detail
int Obj::suppress_errors = 0; int Obj::suppress_errors = 0;
Obj::~Obj() Obj::~Obj() {
{
if ( notify_plugins ) if ( notify_plugins )
PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this)); PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this));
delete location; delete location;
} }
void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const {
const detail::Location* expr_location) const
{
ODesc d; ODesc d;
DoMsg(&d, msg, obj2, pinpoint_only, expr_location); DoMsg(&d, msg, obj2, pinpoint_only, expr_location);
reporter->Warning("%s", d.Description()); reporter->Warning("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const {
const detail::Location* expr_location) const
{
if ( suppress_errors ) if ( suppress_errors )
return; return;
@ -84,10 +72,9 @@ void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only,
DoMsg(&d, msg, obj2, pinpoint_only, expr_location); DoMsg(&d, msg, obj2, pinpoint_only, expr_location);
reporter->Error("%s", d.Description()); reporter->Error("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::BadTag(const char* msg, const char* t1, const char* t2) const void Obj::BadTag(const char* msg, const char* t1, const char* t2) const {
{
char out[512]; char out[512];
if ( t2 ) if ( t2 )
@ -101,10 +88,9 @@ void Obj::BadTag(const char* msg, const char* t1, const char* t2) const
DoMsg(&d, out); DoMsg(&d, out);
reporter->FatalErrorWithCore("%s", d.Description()); reporter->FatalErrorWithCore("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::Internal(const char* msg) const void Obj::Internal(const char* msg) const {
{
ODesc d; ODesc d;
DoMsg(&d, msg); DoMsg(&d, msg);
auto rcs = render_call_stack(); auto rcs = render_call_stack();
@ -115,29 +101,25 @@ void Obj::Internal(const char* msg) const
reporter->InternalError("%s, call stack: %s", d.Description(), rcs.data()); reporter->InternalError("%s, call stack: %s", d.Description(), rcs.data());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::InternalWarning(const char* msg) const void Obj::InternalWarning(const char* msg) const {
{
ODesc d; ODesc d;
DoMsg(&d, msg); DoMsg(&d, msg);
reporter->InternalWarning("%s", d.Description()); reporter->InternalWarning("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::AddLocation(ODesc* d) const void Obj::AddLocation(ODesc* d) const {
{ if ( ! location ) {
if ( ! location )
{
d->Add("<no location>"); d->Add("<no location>");
return; return;
} }
location->Describe(d); location->Describe(d);
} }
bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) {
{
if ( ! start || ! end ) if ( ! start || ! end )
return false; return false;
@ -150,69 +132,59 @@ bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location*
delete location; delete location;
location = new detail::Location(start->filename, start->first_line, end->last_line, location =
start->first_column, end->last_column); new detail::Location(start->filename, start->first_line, end->last_line, start->first_column, end->last_column);
return true; return true;
} }
void Obj::UpdateLocationEndInfo(const detail::Location& end) void Obj::UpdateLocationEndInfo(const detail::Location& end) {
{
if ( ! location ) if ( ! location )
SetLocationInfo(&end, &end); SetLocationInfo(&end, &end);
location->last_line = end.last_line; location->last_line = end.last_line;
location->last_column = end.last_column; location->last_column = end.last_column;
} }
void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only, void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only,
const detail::Location* expr_location) const const detail::Location* expr_location) const {
{
d->SetShort(); d->SetShort();
d->Add(s1); d->Add(s1);
PinPoint(d, obj2, pinpoint_only); PinPoint(d, obj2, pinpoint_only);
const detail::Location* loc2 = nullptr; const detail::Location* loc2 = nullptr;
if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && *obj2->GetLocationInfo() != *GetLocationInfo() )
*obj2->GetLocationInfo() != *GetLocationInfo() )
loc2 = obj2->GetLocationInfo(); loc2 = obj2->GetLocationInfo();
else if ( expr_location ) else if ( expr_location )
loc2 = expr_location; loc2 = expr_location;
reporter->PushLocation(GetLocationInfo(), loc2); reporter->PushLocation(GetLocationInfo(), loc2);
} }
void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const {
{
d->Add(" ("); d->Add(" (");
Describe(d); Describe(d);
if ( obj2 && ! pinpoint_only ) if ( obj2 && ! pinpoint_only ) {
{
d->Add(" and "); d->Add(" and ");
obj2->Describe(d); obj2->Describe(d);
} }
d->Add(")"); d->Add(")");
} }
void Obj::Print() const void Obj::Print() const {
{
static File fstderr(stderr); static File fstderr(stderr);
ODesc d(DESC_READABLE, &fstderr); ODesc d(DESC_READABLE, &fstderr);
Describe(&d); Describe(&d);
d.Add("\n"); d.Add("\n");
} }
void bad_ref(int type) void bad_ref(int type) {
{
reporter->InternalError("bad reference count [%d]", type); reporter->InternalError("bad reference count [%d]", type);
abort(); abort();
} }
void obj_delete_func(void* v) void obj_delete_func(void* v) { Unref((Obj*)v); }
{
Unref((Obj*)v);
}
} // namespace zeek } // namespace zeek

View file

@ -6,22 +6,16 @@
#include <climits> #include <climits>
namespace zeek namespace zeek {
{
class ODesc; class ODesc;
namespace detail namespace detail {
{
class Location final class Location final {
{
public: public:
constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept
: filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), : filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), last_column(col_l) {}
last_column(col_l)
{
}
Location() = default; Location() = default;
@ -33,7 +27,7 @@ public:
const char* filename = nullptr; const char* filename = nullptr;
int first_line = 0, last_line = 0; int first_line = 0, last_line = 0;
int first_column = 0, last_column = 0; int first_column = 0, last_column = 0;
}; };
#define YYLTYPE zeek::detail::yyltype #define YYLTYPE zeek::detail::yyltype
using yyltype = Location; using yyltype = Location;
@ -48,24 +42,18 @@ extern Location start_location;
extern Location end_location; extern Location end_location;
// Used by parser to set the above. // Used by parser to set the above.
inline void set_location(const Location loc) inline void set_location(const Location loc) { start_location = end_location = loc; }
{
start_location = end_location = loc;
}
inline void set_location(const Location start, const Location end) inline void set_location(const Location start, const Location end) {
{
start_location = start; start_location = start;
end_location = end; end_location = end;
} }
} // namespace detail } // namespace detail
class Obj class Obj {
{
public: public:
Obj() Obj() {
{
// A bit of a hack. We'd like to associate location // A bit of a hack. We'd like to associate location
// information with every object created when parsing, // information with every object created when parsing,
// since for them, the location is generally well-defined. // since for them, the location is generally well-defined.
@ -114,10 +102,7 @@ public:
void AddLocation(ODesc* d) const; void AddLocation(ODesc* d) const;
// Get location info for debugging. // Get location info for debugging.
virtual const detail::Location* GetLocationInfo() const virtual const detail::Location* GetLocationInfo() const { return location ? location : &detail::no_location; }
{
return location ? location : &detail::no_location;
}
virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); } virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); }
@ -135,8 +120,7 @@ public:
// Helper class to temporarily suppress errors // Helper class to temporarily suppress errors
// as long as there exist any instances. // as long as there exist any instances.
class SuppressErrors class SuppressErrors {
{
public: public:
SuppressErrors() { ++Obj::suppress_errors; } SuppressErrors() { ++Obj::suppress_errors; }
~SuppressErrors() { --Obj::suppress_errors; } ~SuppressErrors() { --Obj::suppress_errors; }
@ -163,29 +147,23 @@ private:
// If non-zero, do not print runtime errors. Useful for // If non-zero, do not print runtime errors. Useful for
// speculative evaluation. // speculative evaluation.
static int suppress_errors; static int suppress_errors;
}; };
// Sometimes useful when dealing with Obj subclasses that have their // Sometimes useful when dealing with Obj subclasses that have their
// own (protected) versions of Error. // own (protected) versions of Error.
inline void Error(const Obj* o, const char* msg) inline void Error(const Obj* o, const char* msg) { o->Error(msg); }
{
o->Error(msg);
}
[[noreturn]] extern void bad_ref(int type); [[noreturn]] extern void bad_ref(int type);
inline void Ref(Obj* o) inline void Ref(Obj* o) {
{
if ( ++(o->ref_cnt) <= 1 ) if ( ++(o->ref_cnt) <= 1 )
bad_ref(0); bad_ref(0);
if ( o->ref_cnt == INT_MAX ) if ( o->ref_cnt == INT_MAX )
bad_ref(1); bad_ref(1);
} }
inline void Unref(Obj* o) inline void Unref(Obj* o) {
{ if ( o && --o->ref_cnt <= 0 ) {
if ( o && --o->ref_cnt <= 0 )
{
if ( o->ref_cnt < 0 ) if ( o->ref_cnt < 0 )
bad_ref(2); bad_ref(2);
delete o; delete o;
@ -193,9 +171,9 @@ inline void Unref(Obj* o)
// We could do the following if o were passed by reference. // We could do the following if o were passed by reference.
// o = (Obj*) 0xcd; // o = (Obj*) 0xcd;
} }
} }
// A dict_delete_func that knows to Unref() dictionary entries. // A dict_delete_func that knows to Unref() dictionary entries.
extern void obj_delete_func(void* v); extern void obj_delete_func(void* v);
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -21,22 +21,18 @@
#include "zeek/telemetry/Gauge.h" #include "zeek/telemetry/Gauge.h"
#include "zeek/telemetry/Histogram.h" #include "zeek/telemetry/Histogram.h"
namespace broker namespace broker {
{
class data; class data;
} }
namespace zeek namespace zeek {
{
namespace probabilistic namespace probabilistic {
{
class BloomFilter; class BloomFilter;
} }
namespace probabilistic::detail namespace probabilistic::detail {
{
class CardinalityCounter; class CardinalityCounter;
} }
class OpaqueVal; class OpaqueVal;
using OpaqueValPtr = IntrusivePtr<OpaqueVal>; using OpaqueValPtr = IntrusivePtr<OpaqueVal>;
@ -48,8 +44,7 @@ using BloomFilterValPtr = IntrusivePtr<BloomFilterVal>;
* Singleton that registers all available all available types of opaque * Singleton that registers all available all available types of opaque
* values. This facilitates their serialization into Broker values. * values. This facilitates their serialization into Broker values.
*/ */
class OpaqueMgr class OpaqueMgr {
{
public: public:
using Factory = OpaqueValPtr(); using Factory = OpaqueValPtr();
@ -84,15 +79,15 @@ public:
* Internal helper class to register an OpaqueVal-derived classes * Internal helper class to register an OpaqueVal-derived classes
* with the manager. * with the manager.
*/ */
template <class T> class Register template<class T>
{ class Register {
public: public:
Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); } Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); }
}; };
private: private:
std::unordered_map<std::string, Factory*> _types; std::unordered_map<std::string, Factory*> _types;
}; };
/** Macro to insert into an OpaqueVal-derived class's declaration. */ /** Macro to insert into an OpaqueVal-derived class's declaration. */
#define DECLARE_OPAQUE_VALUE(T) \ #define DECLARE_OPAQUE_VALUE(T) \
@ -114,8 +109,7 @@ private:
* completely internally, with no further script-level operators provided * completely internally, with no further script-level operators provided
* (other than bif functions). See OpaqueVal.h for derived classes. * (other than bif functions). See OpaqueVal.h for derived classes.
*/ */
class OpaqueVal : public Val class OpaqueVal : public Val {
{
public: public:
explicit OpaqueVal(OpaqueTypePtr t); explicit OpaqueVal(OpaqueTypePtr t);
~OpaqueVal() override = default; ~OpaqueVal() override = default;
@ -185,14 +179,12 @@ protected:
void ValDescribe(ODesc* d) const override; void ValDescribe(ODesc* d) const override;
void ValDescribeReST(ODesc* d) const override; void ValDescribeReST(ODesc* d) const override;
}; };
class HashVal : public OpaqueVal class HashVal : public OpaqueVal {
{
public: public:
template <class T> template<class T>
static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) {
{
auto h = detail::hash_init(alg); auto h = detail::hash_init(alg);
for ( const auto& v : vlist ) for ( const auto& v : vlist )
@ -219,20 +211,17 @@ protected:
private: private:
// This flag exists because Get() can only be called once. // This flag exists because Get() can only be called once.
bool valid; bool valid;
}; };
class MD5Val : public HashVal class MD5Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) {
digest_all(detail::Hash_MD5, vlist, result); digest_all(detail::Hash_MD5, vlist, result);
} }
template <class T> template<class T>
static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], u_char result[MD5_DIGEST_LENGTH]) {
u_char result[MD5_DIGEST_LENGTH])
{
digest(vlist, result); digest(vlist, result);
for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i ) for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i )
@ -260,13 +249,12 @@ private:
#else #else
MD5_CTX ctx; MD5_CTX ctx;
#endif #endif
}; };
class SHA1Val : public HashVal class SHA1Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) {
digest_all(detail::Hash_SHA1, vlist, result); digest_all(detail::Hash_SHA1, vlist, result);
} }
@ -289,13 +277,12 @@ private:
#else #else
SHA_CTX ctx; SHA_CTX ctx;
#endif #endif
}; };
class SHA256Val : public HashVal class SHA256Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) {
digest_all(detail::Hash_SHA256, vlist, result); digest_all(detail::Hash_SHA256, vlist, result);
} }
@ -318,10 +305,9 @@ private:
#else #else
SHA256_CTX ctx; SHA256_CTX ctx;
#endif #endif
}; };
class EntropyVal : public OpaqueVal class EntropyVal : public OpaqueVal {
{
public: public:
EntropyVal(); EntropyVal();
@ -334,10 +320,9 @@ protected:
DECLARE_OPAQUE_VALUE(EntropyVal) DECLARE_OPAQUE_VALUE(EntropyVal)
private: private:
detail::RandTest state; detail::RandTest state;
}; };
class BloomFilterVal : public OpaqueVal class BloomFilterVal : public OpaqueVal {
{
public: public:
explicit BloomFilterVal(probabilistic::BloomFilter* bf); explicit BloomFilterVal(probabilistic::BloomFilter* bf);
~BloomFilterVal() override; ~BloomFilterVal() override;
@ -371,10 +356,9 @@ private:
TypePtr type; TypePtr type;
detail::CompositeHash* hash; detail::CompositeHash* hash;
probabilistic::BloomFilter* bloom_filter; probabilistic::BloomFilter* bloom_filter;
}; };
class CardinalityVal : public OpaqueVal class CardinalityVal : public OpaqueVal {
{
public: public:
explicit CardinalityVal(probabilistic::detail::CardinalityCounter*); explicit CardinalityVal(probabilistic::detail::CardinalityCounter*);
~CardinalityVal() override; ~CardinalityVal() override;
@ -397,10 +381,9 @@ private:
TypePtr type; TypePtr type;
detail::CompositeHash* hash; detail::CompositeHash* hash;
probabilistic::detail::CardinalityCounter* c; probabilistic::detail::CardinalityCounter* c;
}; };
class ParaglobVal : public OpaqueVal class ParaglobVal : public OpaqueVal {
{
public: public:
explicit ParaglobVal(std::unique_ptr<paraglob::Paraglob> p); explicit ParaglobVal(std::unique_ptr<paraglob::Paraglob> p);
VectorValPtr Get(StringVal*& pattern); VectorValPtr Get(StringVal*& pattern);
@ -408,19 +391,18 @@ public:
bool operator==(const ParaglobVal& other) const; bool operator==(const ParaglobVal& other) const;
protected: protected:
ParaglobVal() : OpaqueVal(paraglob_type) { } ParaglobVal() : OpaqueVal(paraglob_type) {}
DECLARE_OPAQUE_VALUE(ParaglobVal) DECLARE_OPAQUE_VALUE(ParaglobVal)
private: private:
std::unique_ptr<paraglob::Paraglob> internal_paraglob; std::unique_ptr<paraglob::Paraglob> internal_paraglob;
}; };
/** /**
* Base class for metric handles. Handle types are not serializable. * Base class for metric handles. Handle types are not serializable.
*/ */
class TelemetryVal : public OpaqueVal class TelemetryVal : public OpaqueVal {
{
protected: protected:
explicit TelemetryVal(telemetry::IntCounter); explicit TelemetryVal(telemetry::IntCounter);
explicit TelemetryVal(telemetry::IntCounterFamily); explicit TelemetryVal(telemetry::IntCounterFamily);
@ -437,14 +419,14 @@ protected:
broker::expected<broker::data> DoSerialize() const override; broker::expected<broker::data> DoSerialize() const override;
bool DoUnserialize(const broker::data& data) override; bool DoUnserialize(const broker::data& data) override;
}; };
template <class Handle> class TelemetryValImpl : public TelemetryVal template<class Handle>
{ class TelemetryValImpl : public TelemetryVal {
public: public:
using HandleType = Handle; using HandleType = Handle;
explicit TelemetryValImpl(Handle hdl) : TelemetryVal(hdl), hdl(hdl) { } explicit TelemetryValImpl(Handle hdl) : TelemetryVal(hdl), hdl(hdl) {}
Handle GetHandle() const noexcept { return hdl; } Handle GetHandle() const noexcept { return hdl; }
@ -455,7 +437,7 @@ protected:
private: private:
Handle hdl; Handle hdl;
}; };
using IntCounterMetricVal = TelemetryValImpl<telemetry::IntCounter>; using IntCounterMetricVal = TelemetryValImpl<telemetry::IntCounter>;
using IntCounterMetricFamilyVal = TelemetryValImpl<telemetry::IntCounterFamily>; using IntCounterMetricFamilyVal = TelemetryValImpl<telemetry::IntCounterFamily>;
@ -470,4 +452,4 @@ using IntHistogramMetricFamilyVal = TelemetryValImpl<telemetry::IntHistogramFami
using DblHistogramMetricVal = TelemetryValImpl<telemetry::DblHistogram>; using DblHistogramMetricVal = TelemetryValImpl<telemetry::DblHistogram>;
using DblHistogramMetricFamilyVal = TelemetryValImpl<telemetry::DblHistogramFamily>; using DblHistogramMetricFamilyVal = TelemetryValImpl<telemetry::DblHistogramFamily>;
} // namespace zeek } // namespace zeek

View file

@ -19,18 +19,15 @@
#include "zeek/logging/writers/ascii/Ascii.h" #include "zeek/logging/writers/ascii/Ascii.h"
#include "zeek/script_opt/ScriptOpt.h" #include "zeek/script_opt/ScriptOpt.h"
namespace zeek namespace zeek {
{
void Options::filter_supervisor_options() void Options::filter_supervisor_options() {
{
pcap_filter = {}; pcap_filter = {};
signature_files = {}; signature_files = {};
pcap_output_file = {}; pcap_output_file = {};
} }
void Options::filter_supervised_node_options() void Options::filter_supervised_node_options() {
{
auto og = *this; auto og = *this;
*this = {}; *this = {};
@ -69,128 +66,118 @@ void Options::filter_supervised_node_options()
plugins_to_load = og.plugins_to_load; plugins_to_load = og.plugins_to_load;
scripts_to_load = og.scripts_to_load; scripts_to_load = og.scripts_to_load;
script_options_to_set = og.script_options_to_set; script_options_to_set = og.script_options_to_set;
} }
bool fake_dns() bool fake_dns() { return getenv("ZEEK_DNS_FAKE"); }
{
return getenv("ZEEK_DNS_FAKE");
}
extern const char* zeek_version(); extern const char* zeek_version();
void usage(const char* prog, int code) void usage(const char* prog, int code) {
{
fprintf(stderr, "zeek version %s\n", zeek_version()); fprintf(stderr, "zeek version %s\n", zeek_version());
fprintf(stderr, "usage: %s [options] [file ...]\n", prog); fprintf(stderr, "usage: %s [options] [file ...]\n", prog);
fprintf(stderr, "usage: %s --test [doctest-options] -- [options] [file ...]\n", prog); fprintf(stderr, "usage: %s --test [doctest-options] -- [options] [file ...]\n", prog);
fprintf(stderr, " <file> | Zeek script file, or read stdin\n"); fprintf(stderr, " <file> | Zeek script file, or read stdin\n");
fprintf(stderr, fprintf(stderr, " -a|--parse-only | exit immediately after parsing scripts\n");
" -a|--parse-only | exit immediately after parsing scripts\n"); fprintf(stderr, " -b|--bare-mode | don't load scripts from the base/ directory\n");
fprintf(stderr, fprintf(stderr, " -c|--capture-unprocessed <file> | write unprocessed packets to a tcpdump file\n");
" -b|--bare-mode | don't load scripts from the base/ directory\n");
fprintf(stderr,
" -c|--capture-unprocessed <file> | write unprocessed packets to a tcpdump file\n");
fprintf(stderr, " -d|--debug-script | activate Zeek script debugging\n"); fprintf(stderr, " -d|--debug-script | activate Zeek script debugging\n");
fprintf(stderr, " -e|--exec <zeek code> | augment loaded scripts by given code\n"); fprintf(stderr, " -e|--exec <zeek code> | augment loaded scripts by given code\n");
fprintf(stderr, " -f|--filter <filter> | tcpdump filter\n"); fprintf(stderr, " -f|--filter <filter> | tcpdump filter\n");
fprintf(stderr, " -h|--help | command line help\n"); fprintf(stderr, " -h|--help | command line help\n");
fprintf(stderr, " -i|--iface <interface> | read from given interface (only one allowed)\n");
fprintf(stderr, " -p|--prefix <prefix> | add given prefix to Zeek script file resolution\n");
fprintf(stderr, fprintf(stderr,
" -i|--iface <interface> | read from given interface (only one allowed)\n"); " -r|--readfile <readfile> | read from given tcpdump file (only one "
fprintf(
stderr,
" -p|--prefix <prefix> | add given prefix to Zeek script file resolution\n");
fprintf(stderr, " -r|--readfile <readfile> | read from given tcpdump file (only one "
"allowed, pass '-' as the filename to read from stdin)\n"); "allowed, pass '-' as the filename to read from stdin)\n");
fprintf(stderr, " -s|--rulefile <rulefile> | read rules from given file\n"); fprintf(stderr, " -s|--rulefile <rulefile> | read rules from given file\n");
fprintf(stderr, " -t|--tracefile <tracefile> | activate execution tracing\n"); fprintf(stderr, " -t|--tracefile <tracefile> | activate execution tracing\n");
fprintf(stderr, " -u|--usage-issues | find variable usage issues and exit\n"); fprintf(stderr, " -u|--usage-issues | find variable usage issues and exit\n");
fprintf(stderr, " --no-unused-warnings | suppress warnings of unused " fprintf(stderr,
" --no-unused-warnings | suppress warnings of unused "
"functions/hooks/events\n"); "functions/hooks/events\n");
fprintf(stderr, " -v|--version | print version and exit\n"); fprintf(stderr, " -v|--version | print version and exit\n");
fprintf(stderr, " -V|--build-info | print build information and exit\n"); fprintf(stderr, " -V|--build-info | print build information and exit\n");
fprintf(stderr, " -w|--writefile <writefile> | write to given tcpdump file\n"); fprintf(stderr, " -w|--writefile <writefile> | write to given tcpdump file\n");
#ifdef DEBUG #ifdef DEBUG
fprintf(stderr, " -B|--debug <dbgstreams> | Enable debugging output for selected " fprintf(stderr,
" -B|--debug <dbgstreams> | Enable debugging output for selected "
"streams ('-B help' for help)\n"); "streams ('-B help' for help)\n");
#endif #endif
fprintf(stderr, " -C|--no-checksums | ignore checksums\n"); fprintf(stderr, " -C|--no-checksums | ignore checksums\n");
fprintf(stderr, " -D|--deterministic | initialize random seeds to zero\n"); fprintf(stderr, " -D|--deterministic | initialize random seeds to zero\n");
fprintf(stderr, " -E|--event-trace <file> | generate a replayable event trace to " fprintf(stderr,
" -E|--event-trace <file> | generate a replayable event trace to "
"the given file\n"); "the given file\n");
fprintf(stderr, " -F|--force-dns | force DNS\n"); fprintf(stderr, " -F|--force-dns | force DNS\n");
fprintf(stderr, " -G|--load-seeds <file> | load seeds from given file\n"); fprintf(stderr, " -G|--load-seeds <file> | load seeds from given file\n");
fprintf(stderr, " -H|--save-seeds <file> | save seeds to given file\n"); fprintf(stderr, " -H|--save-seeds <file> | save seeds to given file\n");
fprintf(stderr, " -I|--print-id <ID name> | print out given ID\n"); fprintf(stderr, " -I|--print-id <ID name> | print out given ID\n");
fprintf(stderr, " -N|--print-plugins | print available plugins and exit (-NN " fprintf(stderr,
" -N|--print-plugins | print available plugins and exit (-NN "
"for verbose)\n"); "for verbose)\n");
fprintf(stderr, " -O|--optimize <option> | enable script optimization (use -O help " fprintf(stderr,
" -O|--optimize <option> | enable script optimization (use -O help "
"for options)\n"); "for options)\n");
fprintf(stderr, " -0|--optimize-files=<pat> | enable script optimization for all " fprintf(stderr,
" -0|--optimize-files=<pat> | enable script optimization for all "
"functions in files with names containing the given pattern\n"); "functions in files with names containing the given pattern\n");
fprintf(stderr, " -o|--optimize-funcs=<pat> | enable script optimization for " fprintf(stderr,
" -o|--optimize-funcs=<pat> | enable script optimization for "
"functions with names fully matching the given pattern\n"); "functions with names fully matching the given pattern\n");
fprintf(stderr, " -P|--prime-dns | prime DNS\n"); fprintf(stderr, " -P|--prime-dns | prime DNS\n");
fprintf(stderr, fprintf(stderr, " -Q|--time | print execution time summary to stderr\n");
" -Q|--time | print execution time summary to stderr\n");
fprintf(stderr, " -S|--debug-rules | enable rule debugging\n"); fprintf(stderr, " -S|--debug-rules | enable rule debugging\n");
fprintf(stderr, " -T|--re-level <level> | set 'RE_level' for rules\n"); fprintf(stderr, " -T|--re-level <level> | set 'RE_level' for rules\n");
fprintf(stderr, " -U|--status-file <file> | Record process status in file\n"); fprintf(stderr, " -U|--status-file <file> | Record process status in file\n");
fprintf(stderr, " -W|--watchdog | activate watchdog timer\n"); fprintf(stderr, " -W|--watchdog | activate watchdog timer\n");
fprintf(stderr, fprintf(stderr, " -X|--zeekygen <cfgfile> | generate documentation based on config file\n");
" -X|--zeekygen <cfgfile> | generate documentation based on config file\n");
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
fprintf(stderr, " -m|--mem-leaks | show leaks [perftools]\n"); fprintf(stderr, " -m|--mem-leaks | show leaks [perftools]\n");
fprintf(stderr, " -M|--mem-profile | record heap [perftools]\n"); fprintf(stderr, " -M|--mem-profile | record heap [perftools]\n");
#endif #endif
fprintf( fprintf(stderr, " --profile-scripts[=file] | profile scripts to given file (default stdout)\n");
stderr,
" --profile-scripts[=file] | profile scripts to given file (default stdout)\n");
fprintf(stderr, fprintf(stderr,
" --profile-script-call-stacks | add call stacks to profile output (requires " " --profile-script-call-stacks | add call stacks to profile output (requires "
"--profile-scripts)\n"); "--profile-scripts)\n");
fprintf(stderr, " --pseudo-realtime[=<speedup>] | enable pseudo-realtime for performance " fprintf(stderr,
" --pseudo-realtime[=<speedup>] | enable pseudo-realtime for performance "
"evaluation (default 1)\n"); "evaluation (default 1)\n");
fprintf(stderr, " -j|--jobs | enable supervisor mode\n"); fprintf(stderr, " -j|--jobs | enable supervisor mode\n");
fprintf(stderr, " --test | run unit tests ('--test -h' for help, " fprintf(stderr,
" --test | run unit tests ('--test -h' for help, "
"not available when built without ENABLE_ZEEK_UNIT_TESTS)\n"); "not available when built without ENABLE_ZEEK_UNIT_TESTS)\n");
fprintf(stderr, " $ZEEKPATH | file search path (%s)\n", fprintf(stderr, " $ZEEKPATH | file search path (%s)\n", util::zeek_path().c_str());
util::zeek_path().c_str()); fprintf(stderr, " $ZEEK_PLUGIN_PATH | plugin search path (%s)\n", util::zeek_plugin_path());
fprintf(stderr, " $ZEEK_PLUGIN_PATH | plugin search path (%s)\n",
util::zeek_plugin_path());
fprintf(stderr, " $ZEEK_PLUGIN_ACTIVATE | plugins to always activate (%s)\n", fprintf(stderr, " $ZEEK_PLUGIN_ACTIVATE | plugins to always activate (%s)\n",
util::zeek_plugin_activate()); util::zeek_plugin_activate());
fprintf(stderr, " $ZEEK_PREFIXES | prefix list (%s)\n", fprintf(stderr, " $ZEEK_PREFIXES | prefix list (%s)\n", util::zeek_prefixes().c_str());
util::zeek_prefixes().c_str()); fprintf(stderr, " $ZEEK_DNS_FAKE | disable DNS lookups (%s)\n", fake_dns() ? "on" : "off");
fprintf(stderr, " $ZEEK_DNS_FAKE | disable DNS lookups (%s)\n",
fake_dns() ? "on" : "off");
fprintf(stderr, " $ZEEK_SEED_VALUES | list of space separated seeds (%s)\n", fprintf(stderr, " $ZEEK_SEED_VALUES | list of space separated seeds (%s)\n",
getenv("ZEEK_SEED_VALUES") ? "set" : "not set"); getenv("ZEEK_SEED_VALUES") ? "set" : "not set");
fprintf(stderr, " $ZEEK_SEED_FILE | file to load seeds from (not set)\n"); fprintf(stderr, " $ZEEK_SEED_FILE | file to load seeds from (not set)\n");
fprintf(stderr, " $ZEEK_LOG_SUFFIX | ASCII log file extension (.%s)\n", fprintf(stderr, " $ZEEK_LOG_SUFFIX | ASCII log file extension (.%s)\n",
logging::writer::detail::Ascii::LogExt().c_str()); logging::writer::detail::Ascii::LogExt().c_str());
fprintf(stderr, " $ZEEK_PROFILER_FILE | Output file for script execution " fprintf(stderr,
" $ZEEK_PROFILER_FILE | Output file for script execution "
"statistics (not set)\n"); "statistics (not set)\n");
fprintf(stderr, fprintf(stderr, " $ZEEK_DISABLE_ZEEKYGEN | Disable Zeekygen documentation support (%s)\n",
" $ZEEK_DISABLE_ZEEKYGEN | Disable Zeekygen documentation support (%s)\n",
getenv("ZEEK_DISABLE_ZEEKYGEN") ? "set" : "not set"); getenv("ZEEK_DISABLE_ZEEKYGEN") ? "set" : "not set");
fprintf(stderr, " $ZEEK_DNS_RESOLVER | IPv4/IPv6 address of DNS resolver to use (%s)\n",
getenv("ZEEK_DNS_RESOLVER") ? getenv("ZEEK_DNS_RESOLVER") :
"not set, will use first IPv4 address from /etc/resolv.conf");
fprintf(stderr, fprintf(stderr,
" $ZEEK_DNS_RESOLVER | IPv4/IPv6 address of DNS resolver to use (%s)\n", " $ZEEK_DEBUG_LOG_STDERR | Use stderr for debug logs generated via "
getenv("ZEEK_DNS_RESOLVER")
? getenv("ZEEK_DNS_RESOLVER")
: "not set, will use first IPv4 address from /etc/resolv.conf");
fprintf(stderr, " $ZEEK_DEBUG_LOG_STDERR | Use stderr for debug logs generated via "
"the -B flag"); "the -B flag");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
exit(code); exit(code);
} }
static void print_analysis_help() static void print_analysis_help() {
{
fprintf(stderr, "--optimize options when using ZAM:\n"); fprintf(stderr, "--optimize options when using ZAM:\n");
fprintf(stderr, " ZAM execute scripts using ZAM and all optimizations\n"); fprintf(stderr, " ZAM execute scripts using ZAM and all optimizations\n");
fprintf(stderr, " help print this list\n"); fprintf(stderr, " help print this list\n");
@ -199,42 +186,35 @@ static void print_analysis_help()
fprintf(stderr, " dump-uds dump use-defs to stdout; implies xform\n"); fprintf(stderr, " dump-uds dump use-defs to stdout; implies xform\n");
fprintf(stderr, " dump-xform dump transformed scripts to stdout; implies xform\n"); fprintf(stderr, " dump-xform dump transformed scripts to stdout; implies xform\n");
fprintf(stderr, " dump-ZAM dump generated ZAM code; implies gen-ZAM-code\n"); fprintf(stderr, " dump-ZAM dump generated ZAM code; implies gen-ZAM-code\n");
fprintf(stderr, fprintf(stderr, " gen-ZAM-code generate ZAM code (without turning on additional optimizations)\n");
" gen-ZAM-code generate ZAM code (without turning on additional optimizations)\n");
fprintf(stderr, " inline inline function calls\n"); fprintf(stderr, " inline inline function calls\n");
fprintf(stderr, " no-ZAM-opt omit low-level ZAM optimization\n"); fprintf(stderr, " no-ZAM-opt omit low-level ZAM optimization\n");
fprintf(stderr, " optimize-all optimize all scripts, even inlined ones\n"); fprintf(stderr, " optimize-all optimize all scripts, even inlined ones\n");
fprintf(stderr, " optimize-AST optimize the (transformed) AST; implies xform\n"); fprintf(stderr, " optimize-AST optimize the (transformed) AST; implies xform\n");
fprintf(stderr, fprintf(stderr, " profile-ZAM generate to stdout a ZAM execution profile; implies -O ZAM\n");
" profile-ZAM generate to stdout a ZAM execution profile; implies -O ZAM\n");
fprintf(stderr, " report-recursive report on recursive functions and exit\n"); fprintf(stderr, " report-recursive report on recursive functions and exit\n");
fprintf(stderr, " xform transform scripts to \"reduced\" form\n"); fprintf(stderr, " xform transform scripts to \"reduced\" form\n");
fprintf(stderr, "\n--optimize options when generating C++:\n"); fprintf(stderr, "\n--optimize options when generating C++:\n");
fprintf( fprintf(stderr, " allow-cond allow standalone compilation of functions influenced by conditionals\n");
stderr,
" allow-cond allow standalone compilation of functions influenced by conditionals\n");
fprintf(stderr, " gen-C++ generate C++ script bodies\n"); fprintf(stderr, " gen-C++ generate C++ script bodies\n");
fprintf(stderr, " gen-standalone-C++ generate \"standalone\" C++ script bodies\n"); fprintf(stderr, " gen-standalone-C++ generate \"standalone\" C++ script bodies\n");
fprintf(stderr, " help print this list\n"); fprintf(stderr, " help print this list\n");
fprintf(stderr, " report-C++ report available C++ script bodies and exit\n"); fprintf(stderr, " report-C++ report available C++ script bodies and exit\n");
fprintf(stderr, " report-uncompilable print names of functions that can't be compiled\n"); fprintf(stderr, " report-uncompilable print names of functions that can't be compiled\n");
fprintf(stderr, " use-C++ use available C++ script bodies\n"); fprintf(stderr, " use-C++ use available C++ script bodies\n");
} }
static void set_analysis_option(const char* opt, Options& opts) static void set_analysis_option(const char* opt, Options& opts) {
{
auto& a_o = opts.analysis_options; auto& a_o = opts.analysis_options;
if ( ! opt || util::streq(opt, "ZAM") ) if ( ! opt || util::streq(opt, "ZAM") ) {
{
a_o.inliner = a_o.optimize_AST = a_o.activate = true; a_o.inliner = a_o.optimize_AST = a_o.activate = true;
a_o.gen_ZAM = true; a_o.gen_ZAM = true;
return; return;
} }
if ( util::streq(opt, "help") ) if ( util::streq(opt, "help") ) {
{
print_analysis_help(); print_analysis_help();
exit(0); exit(0);
} }
@ -274,16 +254,14 @@ static void set_analysis_option(const char* opt, Options& opts)
else if ( util::streq(opt, "xform") ) else if ( util::streq(opt, "xform") )
a_o.activate = true; a_o.activate = true;
else else {
{
fprintf(stderr, "zeek: unrecognized -O/--optimize option: %s\n\n", opt); fprintf(stderr, "zeek: unrecognized -O/--optimize option: %s\n\n", opt);
print_analysis_help(); print_analysis_help();
exit(1); exit(1);
} }
} }
Options parse_cmdline(int argc, char** argv) Options parse_cmdline(int argc, char** argv) {
{
Options rval; Options rval;
// When running unit tests, the first argument on the command line must be // When running unit tests, the first argument on the command line must be
@ -295,26 +273,22 @@ Options parse_cmdline(int argc, char** argv)
// Just locally filtering out the args for Zeek usage from doctest args. // Just locally filtering out the args for Zeek usage from doctest args.
std::vector<std::string> zeek_args; std::vector<std::string> zeek_args;
if ( argc > 1 && strcmp(argv[1], "--test") == 0 ) if ( argc > 1 && strcmp(argv[1], "--test") == 0 ) {
{
#ifdef DOCTEST_CONFIG_DISABLE #ifdef DOCTEST_CONFIG_DISABLE
fprintf(stderr, "ERROR: C++ unit tests are disabled for this build.\n" fprintf(stderr,
"ERROR: C++ unit tests are disabled for this build.\n"
" Please re-compile with ENABLE_ZEEK_UNIT_TESTS " " Please re-compile with ENABLE_ZEEK_UNIT_TESTS "
"to run the C++ unit tests.\n"); "to run the C++ unit tests.\n");
exit(1); exit(1);
#endif #endif
auto is_separator = [](const char* cstr) auto is_separator = [](const char* cstr) { return strcmp(cstr, "--") == 0; };
{
return strcmp(cstr, "--") == 0;
};
auto first = argv; auto first = argv;
auto last = argv + argc; auto last = argv + argc;
auto separator = std::find_if(first, last, is_separator); auto separator = std::find_if(first, last, is_separator);
zeek_args.emplace_back(argv[0]); zeek_args.emplace_back(argv[0]);
if ( separator != last ) if ( separator != last ) {
{
auto first_zeek_arg = std::next(separator); auto first_zeek_arg = std::next(separator);
for ( auto i = first_zeek_arg; i != last; ++i ) for ( auto i = first_zeek_arg; i != last; ++i )
@ -326,34 +300,26 @@ Options parse_cmdline(int argc, char** argv)
for ( ptrdiff_t i = 0; i < std::distance(first, separator); ++i ) for ( ptrdiff_t i = 0; i < std::distance(first, separator); ++i )
rval.doctest_args.emplace_back(argv[i]); rval.doctest_args.emplace_back(argv[i]);
} }
else else {
{ if ( argc > 1 ) {
if ( argc > 1 ) auto endsWith = [](const std::string& str, const std::string& suffix) {
{
auto endsWith = [](const std::string& str, const std::string& suffix)
{
return str.size() >= suffix.size() && return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}; };
auto i = 0; auto i = 0;
for ( ; i < argc && ! endsWith(argv[i], "--"); ++i ) for ( ; i < argc && ! endsWith(argv[i], "--"); ++i ) {
{
zeek_args.emplace_back(argv[i]); zeek_args.emplace_back(argv[i]);
} }
if ( i < argc ) if ( i < argc ) {
{
// If a script is invoked with Zeek as the interpreter, the arguments provided // If a script is invoked with Zeek as the interpreter, the arguments provided
// directly in the interpreter line of the script won't be broken apart in the // directly in the interpreter line of the script won't be broken apart in the
// argv on Linux so we split it up here. // argv on Linux so we split it up here.
if ( endsWith(argv[i], "--") && zeek_args.size() == 1 ) if ( endsWith(argv[i], "--") && zeek_args.size() == 1 ) {
{
std::istringstream iss(argv[i]); std::istringstream iss(argv[i]);
for ( std::string s; iss >> s; ) for ( std::string s; iss >> s; ) {
{ if ( ! endsWith(s, "--") ) {
if ( ! endsWith(s, "--") )
{
zeek_args.emplace_back(s); zeek_args.emplace_back(s);
} }
} }
@ -434,8 +400,7 @@ Options parse_cmdline(int argc, char** argv)
}; };
char opts[256]; char opts[256];
util::safe_strncpy(opts, "B:c:E:e:f:G:H:I:i:j::n:O:0:o:p:r:s:T:t:U:w:X:CDFMNPQSWabdhmuvV", util::safe_strncpy(opts, "B:c:E:e:f:G:H:I:i:j::n:O:0:o:p:r:s:T:t:U:w:X:CDFMNPQSWabdhmuvV", sizeof(opts));
sizeof(opts));
int op; int op;
int long_optsind; int long_optsind;
@ -447,40 +412,22 @@ Options parse_cmdline(int argc, char** argv)
for ( size_t i = 0; i < zeek_args.size(); ++i ) for ( size_t i = 0; i < zeek_args.size(); ++i )
zargs[i] = zeek_args[i].data(); zargs[i] = zeek_args[i].data();
while ( (op = getopt_long(zeek_args.size(), zargs.get(), opts, long_opts, &long_optsind)) != while ( (op = getopt_long(zeek_args.size(), zargs.get(), opts, long_opts, &long_optsind)) != EOF )
EOF ) switch ( op ) {
switch ( op ) case 'a': rval.parse_only = true; break;
{ case 'b': rval.bare_mode = true; break;
case 'a': case 'c': rval.unprocessed_output_file = optarg; break;
rval.parse_only = true; case 'd': rval.debug_scripts = true; break;
break; case 'e': rval.script_code_to_exec = optarg; break;
case 'b': case 'f': rval.pcap_filter = optarg; break;
rval.bare_mode = true; case 'h': rval.print_usage = true; break;
break;
case 'c':
rval.unprocessed_output_file = optarg;
break;
case 'd':
rval.debug_scripts = true;
break;
case 'e':
rval.script_code_to_exec = optarg;
break;
case 'f':
rval.pcap_filter = optarg;
break;
case 'h':
rval.print_usage = true;
break;
case 'i': case 'i':
if ( rval.interface ) if ( rval.interface ) {
{
fprintf(stderr, "ERROR: Only a single interface option (-i) is allowed.\n"); fprintf(stderr, "ERROR: Only a single interface option (-i) is allowed.\n");
exit(1); exit(1);
} }
if ( rval.pcap_file ) if ( rval.pcap_file ) {
{
fprintf(stderr, "ERROR: Using -i is not allow when reading a pcap file.\n"); fprintf(stderr, "ERROR: Using -i is not allow when reading a pcap file.\n");
exit(1); exit(1);
} }
@ -489,128 +436,74 @@ Options parse_cmdline(int argc, char** argv)
break; break;
case 'j': case 'j':
rval.supervisor_mode = true; rval.supervisor_mode = true;
if ( optarg ) if ( optarg ) {
{
// TODO: for supervised offline pcap reading, the argument is // TODO: for supervised offline pcap reading, the argument is
// expected to be number of workers like "-j 4" or possibly a // expected to be number of workers like "-j 4" or possibly a
// list of worker/proxy/logger counts like "-j 4,2,1" // list of worker/proxy/logger counts like "-j 4,2,1"
} }
break; break;
case 'p': case 'p': rval.script_prefixes.emplace_back(optarg); break;
rval.script_prefixes.emplace_back(optarg);
break;
case 'r': case 'r':
if ( rval.pcap_file ) if ( rval.pcap_file ) {
{
fprintf(stderr, "ERROR: Only a single readfile option (-r) is allowed.\n"); fprintf(stderr, "ERROR: Only a single readfile option (-r) is allowed.\n");
exit(1); exit(1);
} }
if ( rval.interface ) if ( rval.interface ) {
{
fprintf(stderr, "Using -r is not allowed when reading a live interface.\n"); fprintf(stderr, "Using -r is not allowed when reading a live interface.\n");
exit(1); exit(1);
} }
rval.pcap_file = optarg; rval.pcap_file = optarg;
break; break;
case 's': case 's': rval.signature_files.emplace_back(optarg); break;
rval.signature_files.emplace_back(optarg); case 't': rval.debug_script_tracing_file = optarg; break;
break; case 'u': ++rval.analysis_options.usage_issues; break;
case 't': case 'v': rval.print_version = true; break;
rval.debug_script_tracing_file = optarg; case 'V': rval.print_build_info = true; break;
break; case 'w': rval.pcap_output_file = optarg; break;
case 'u':
++rval.analysis_options.usage_issues;
break;
case 'v':
rval.print_version = true;
break;
case 'V':
rval.print_build_info = true;
break;
case 'w':
rval.pcap_output_file = optarg;
break;
case 'B': case 'B':
#ifdef DEBUG #ifdef DEBUG
rval.debug_log_streams = optarg; rval.debug_log_streams = optarg;
#else #else
if ( util::streq(optarg, "help") ) if ( util::streq(optarg, "help") ) {
{
fprintf(stderr, "debug streams unavailable\n"); fprintf(stderr, "debug streams unavailable\n");
exit(1); exit(1);
} }
#endif #endif
break; break;
case 'C': case 'C': rval.ignore_checksums = true; break;
rval.ignore_checksums = true; case 'D': rval.deterministic_mode = true; break;
break; case 'E': rval.event_trace_file = optarg; break;
case 'D':
rval.deterministic_mode = true;
break;
case 'E':
rval.event_trace_file = optarg;
break;
case 'F': case 'F':
if ( rval.dns_mode != detail::DNS_DEFAULT ) if ( rval.dns_mode != detail::DNS_DEFAULT )
usage(zargs[0], 1); usage(zargs[0], 1);
rval.dns_mode = detail::DNS_FORCE; rval.dns_mode = detail::DNS_FORCE;
break; break;
case 'G': case 'G': rval.random_seed_input_file = optarg; break;
rval.random_seed_input_file = optarg; case 'H': rval.random_seed_output_file = optarg; break;
break; case 'I': rval.identifier_to_print = optarg; break;
case 'H': case 'N': ++rval.print_plugins; break;
rval.random_seed_output_file = optarg; case 'O': set_analysis_option(optarg, rval); break;
break; case 'o': add_func_analysis_pattern(rval.analysis_options, optarg); break;
case 'I': case '0': add_file_analysis_pattern(rval.analysis_options, optarg); break;
rval.identifier_to_print = optarg;
break;
case 'N':
++rval.print_plugins;
break;
case 'O':
set_analysis_option(optarg, rval);
break;
case 'o':
add_func_analysis_pattern(rval.analysis_options, optarg);
break;
case '0':
add_file_analysis_pattern(rval.analysis_options, optarg);
break;
case 'P': case 'P':
if ( rval.dns_mode != detail::DNS_DEFAULT ) if ( rval.dns_mode != detail::DNS_DEFAULT )
usage(zargs[0], 1); usage(zargs[0], 1);
rval.dns_mode = detail::DNS_PRIME; rval.dns_mode = detail::DNS_PRIME;
break; break;
case 'Q': case 'Q': rval.print_execution_time = true; break;
rval.print_execution_time = true; case 'S': rval.print_signature_debug_info = true; break;
break; case 'T': rval.signature_re_level = atoi(optarg); break;
case 'S': case 'U': rval.process_status_file = optarg; break;
rval.print_signature_debug_info = true; case 'W': rval.use_watchdog = true; break;
break; case 'X': rval.zeekygen_config_file = optarg; break;
case 'T':
rval.signature_re_level = atoi(optarg);
break;
case 'U':
rval.process_status_file = optarg;
break;
case 'W':
rval.use_watchdog = true;
break;
case 'X':
rval.zeekygen_config_file = optarg;
break;
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
case 'm': case 'm': rval.perftools_check_leaks = 1; break;
rval.perftools_check_leaks = 1; case 'M': rval.perftools_profile = 1; break;
break;
case 'M':
rval.perftools_profile = 1;
break;
#endif #endif
case '~': case '~':
@ -627,15 +520,13 @@ Options parse_cmdline(int argc, char** argv)
case 0: case 0:
// This happens for long options that don't have // This happens for long options that don't have
// a short-option equivalent. // a short-option equivalent.
if ( profile_scripts ) if ( profile_scripts ) {
{
profile_filename = optarg ? optarg : ""; profile_filename = optarg ? optarg : "";
enable_script_profile = true; enable_script_profile = true;
profile_scripts = 0; profile_scripts = 0;
} }
if ( profile_script_call_stacks ) if ( profile_script_call_stacks ) {
{
enable_script_profile_call_stacks = true; enable_script_profile_call_stacks = true;
profile_script_call_stacks = 0; profile_script_call_stacks = 0;
} }
@ -645,18 +536,13 @@ Options parse_cmdline(int argc, char** argv)
break; break;
case '?': case '?':
default: default: usage(zargs[0], 1); break;
usage(zargs[0], 1);
break;
} }
if ( ! enable_script_profile && enable_script_profile_call_stacks ) if ( ! enable_script_profile && enable_script_profile_call_stacks )
fprintf( fprintf(stderr, "ERROR: --profile-scripts-traces requires --profile-scripts to be passed as well.\n");
stderr,
"ERROR: --profile-scripts-traces requires --profile-scripts to be passed as well.\n");
if ( enable_script_profile ) if ( enable_script_profile ) {
{
activate_script_profiling(profile_filename.empty() ? nullptr : profile_filename.c_str(), activate_script_profiling(profile_filename.empty() ? nullptr : profile_filename.c_str(),
enable_script_profile_call_stacks); enable_script_profile_call_stacks);
} }
@ -664,8 +550,7 @@ Options parse_cmdline(int argc, char** argv)
// Process remaining arguments. X=Y arguments indicate script // Process remaining arguments. X=Y arguments indicate script
// variable/parameter assignments. X::Y arguments indicate plugins to // variable/parameter assignments. X::Y arguments indicate plugins to
// activate/query. The remainder are treated as scripts to load. // activate/query. The remainder are treated as scripts to load.
while ( optind < static_cast<int>(zeek_args.size()) ) while ( optind < static_cast<int>(zeek_args.size()) ) {
{
if ( strchr(zargs[optind], '=') ) if ( strchr(zargs[optind], '=') )
rval.script_options_to_set.emplace_back(zargs[optind++]); rval.script_options_to_set.emplace_back(zargs[optind++]);
else if ( strstr(zargs[optind], "::") ) else if ( strstr(zargs[optind], "::") )
@ -674,8 +559,7 @@ Options parse_cmdline(int argc, char** argv)
rval.scripts_to_load.emplace_back(zargs[optind++]); rval.scripts_to_load.emplace_back(zargs[optind++]);
} }
auto canonify_script_path = [](std::string* path) auto canonify_script_path = [](std::string* path) {
{
if ( path->empty() ) if ( path->empty() )
return; return;
@ -685,13 +569,11 @@ Options parse_cmdline(int argc, char** argv)
// Absolute path // Absolute path
return; return;
if ( (*path)[0] != '.' ) if ( (*path)[0] != '.' ) {
{
// Look up file in ZEEKPATH // Look up file in ZEEKPATH
auto res = util::find_script_file(*path, util::zeek_path()); auto res = util::find_script_file(*path, util::zeek_path());
if ( res.empty() ) if ( res.empty() ) {
{
fprintf(stderr, "failed to locate script: %s\n", path->data()); fprintf(stderr, "failed to locate script: %s\n", path->data());
exit(1); exit(1);
} }
@ -706,8 +588,7 @@ Options parse_cmdline(int argc, char** argv)
// Need to translate relative path to absolute. // Need to translate relative path to absolute.
char cwd[PATH_MAX]; char cwd[PATH_MAX];
if ( ! getcwd(cwd, sizeof(cwd)) ) if ( ! getcwd(cwd, sizeof(cwd)) ) {
{
fprintf(stderr, "failed to get current directory: %s\n", strerror(errno)); fprintf(stderr, "failed to get current directory: %s\n", strerror(errno));
exit(1); exit(1);
} }
@ -715,8 +596,7 @@ Options parse_cmdline(int argc, char** argv)
*path = std::string(cwd) + "/" + *path; *path = std::string(cwd) + "/" + *path;
}; };
if ( rval.supervisor_mode ) if ( rval.supervisor_mode ) {
{
// Translate any relative paths supplied to supervisor into absolute // Translate any relative paths supplied to supervisor into absolute
// paths for use by supervised nodes since they have the option to // paths for use by supervised nodes since they have the option to
// operate out of a different working directory. // operate out of a different working directory.
@ -725,6 +605,6 @@ Options parse_cmdline(int argc, char** argv)
} }
return rval; return rval;
} }
} // namespace zeek } // namespace zeek

View file

@ -9,15 +9,13 @@
#include "zeek/DNS_Mgr.h" #include "zeek/DNS_Mgr.h"
#include "zeek/script_opt/ScriptOpt.h" #include "zeek/script_opt/ScriptOpt.h"
namespace zeek namespace zeek {
{
/** /**
* Options that define general Zeek processing behavior, usually determined * Options that define general Zeek processing behavior, usually determined
* from command-line arguments. * from command-line arguments.
*/ */
struct Options struct Options {
{
/** /**
* Unset options that aren't meant to be used by the supervisor, but may * Unset options that aren't meant to be used by the supervisor, but may
* make sense for supervised nodes to inherit (as opposed to flagging * make sense for supervised nodes to inherit (as opposed to flagging
@ -85,7 +83,7 @@ struct Options
// For script optimization: // For script optimization:
detail::AnalyOpt analysis_options; detail::AnalyOpt analysis_options;
}; };
/** /**
* Parse Zeek command-line arguments. * Parse Zeek command-line arguments.
@ -107,4 +105,4 @@ void usage(const char* prog, int code = 1);
*/ */
bool fake_dns(); bool fake_dns();
} // namespace zeek } // namespace zeek

View file

@ -4,11 +4,9 @@
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val) bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val) {
{
if ( ! to_type || ! from_type ) if ( ! to_type || ! from_type )
return true; return true;
@ -18,16 +16,14 @@ bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, cons
if ( to_type->InternalType() == TYPE_INTERNAL_DOUBLE ) if ( to_type->InternalType() == TYPE_INTERNAL_DOUBLE )
return false; return false;
if ( to_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) if ( to_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) {
{
if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE ) if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE )
return double_to_count_would_overflow(val->InternalDouble()); return double_to_count_would_overflow(val->InternalDouble());
if ( from_type->InternalType() == TYPE_INTERNAL_INT ) if ( from_type->InternalType() == TYPE_INTERNAL_INT )
return int_to_count_would_overflow(val->InternalInt()); return int_to_count_would_overflow(val->InternalInt());
} }
if ( to_type->InternalType() == TYPE_INTERNAL_INT ) if ( to_type->InternalType() == TYPE_INTERNAL_INT ) {
{
if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE ) if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE )
return double_to_int_would_overflow(val->InternalDouble()); return double_to_int_would_overflow(val->InternalDouble());
if ( from_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) if ( from_type->InternalType() == TYPE_INTERNAL_UNSIGNED )
@ -35,6 +31,6 @@ bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, cons
} }
return false; return false;
} }
} } // namespace zeek::detail

View file

@ -4,29 +4,18 @@
#include "zeek/Type.h" #include "zeek/Type.h"
namespace zeek::detail namespace zeek::detail {
{
inline bool double_to_count_would_overflow(double v) inline bool double_to_count_would_overflow(double v) { return v < 0.0 || v > static_cast<double>(UINT64_MAX); }
{
return v < 0.0 || v > static_cast<double>(UINT64_MAX);
}
inline bool int_to_count_would_overflow(zeek_int_t v) inline bool int_to_count_would_overflow(zeek_int_t v) { return v < 0.0; }
{
return v < 0.0;
}
inline bool double_to_int_would_overflow(double v) inline bool double_to_int_would_overflow(double v) {
{
return v < static_cast<double>(INT64_MIN) || v > static_cast<double>(INT64_MAX); return v < static_cast<double>(INT64_MIN) || v > static_cast<double>(INT64_MAX);
} }
inline bool count_to_int_would_overflow(zeek_uint_t v) inline bool count_to_int_would_overflow(zeek_uint_t v) { return v > INT64_MAX; }
{
return v > INT64_MAX;
}
extern bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val); extern bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val);
} } // namespace zeek::detail

View file

@ -2,88 +2,76 @@
#include "zeek/IP.h" #include "zeek/IP.h"
namespace zeek::detail namespace zeek::detail {
{
void PacketFilter::DeleteFilter(void* data) void PacketFilter::DeleteFilter(void* data) {
{
auto f = static_cast<Filter*>(data); auto f = static_cast<Filter*>(data);
delete f; delete f;
} }
PacketFilter::PacketFilter(bool arg_default) PacketFilter::PacketFilter(bool arg_default) {
{
default_match = arg_default; default_match = arg_default;
src_filter.SetDeleteFunction(PacketFilter::DeleteFilter); src_filter.SetDeleteFunction(PacketFilter::DeleteFilter);
dst_filter.SetDeleteFunction(PacketFilter::DeleteFilter); dst_filter.SetDeleteFunction(PacketFilter::DeleteFilter);
} }
void PacketFilter::AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability) void PacketFilter::AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
auto prev = static_cast<Filter*>(src_filter.Insert(src, 128, f)); auto prev = static_cast<Filter*>(src_filter.Insert(src, 128, f));
delete prev; delete prev;
} }
void PacketFilter::AddSrc(Val* src, uint32_t tcp_flags, double probability) void PacketFilter::AddSrc(Val* src, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
auto prev = static_cast<Filter*>(src_filter.Insert(src, f)); auto prev = static_cast<Filter*>(src_filter.Insert(src, f));
delete prev; delete prev;
} }
void PacketFilter::AddDst(const IPAddr& dst, uint32_t tcp_flags, double probability) void PacketFilter::AddDst(const IPAddr& dst, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
auto prev = static_cast<Filter*>(dst_filter.Insert(dst, 128, f)); auto prev = static_cast<Filter*>(dst_filter.Insert(dst, 128, f));
delete prev; delete prev;
} }
void PacketFilter::AddDst(Val* dst, uint32_t tcp_flags, double probability) void PacketFilter::AddDst(Val* dst, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
auto prev = static_cast<Filter*>(dst_filter.Insert(dst, f)); auto prev = static_cast<Filter*>(dst_filter.Insert(dst, f));
delete prev; delete prev;
} }
bool PacketFilter::RemoveSrc(const IPAddr& src) bool PacketFilter::RemoveSrc(const IPAddr& src) {
{
auto f = static_cast<Filter*>(src_filter.Remove(src, 128)); auto f = static_cast<Filter*>(src_filter.Remove(src, 128));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::RemoveSrc(Val* src) bool PacketFilter::RemoveSrc(Val* src) {
{
auto f = static_cast<Filter*>(src_filter.Remove(src)); auto f = static_cast<Filter*>(src_filter.Remove(src));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::RemoveDst(const IPAddr& dst) bool PacketFilter::RemoveDst(const IPAddr& dst) {
{
auto f = static_cast<Filter*>(dst_filter.Remove(dst, 128)); auto f = static_cast<Filter*>(dst_filter.Remove(dst, 128));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::RemoveDst(Val* dst) bool PacketFilter::RemoveDst(Val* dst) {
{
auto f = static_cast<Filter*>(dst_filter.Remove(dst)); auto f = static_cast<Filter*>(dst_filter.Remove(dst));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) {
{
Filter* f = (Filter*)src_filter.Lookup(ip->SrcAddr(), 128); Filter* f = (Filter*)src_filter.Lookup(ip->SrcAddr(), 128);
if ( f ) if ( f )
return MatchFilter(*f, *ip, len, caplen); return MatchFilter(*f, *ip, len, caplen);
@ -93,19 +81,16 @@ bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen)
return MatchFilter(*f, *ip, len, caplen); return MatchFilter(*f, *ip, len, caplen);
return default_match; return default_match;
} }
bool PacketFilter::MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen) bool PacketFilter::MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen) {
{ if ( ip.NextProto() == IPPROTO_TCP && f.tcp_flags ) {
if ( ip.NextProto() == IPPROTO_TCP && f.tcp_flags )
{
// Caution! The packet sanity checks have not been performed yet // Caution! The packet sanity checks have not been performed yet
int ip_hdr_len = ip.HdrLen(); int ip_hdr_len = ip.HdrLen();
len -= ip_hdr_len; // remove IP header len -= ip_hdr_len; // remove IP header
caplen -= ip_hdr_len; caplen -= ip_hdr_len;
if ( (unsigned int)len < sizeof(struct tcphdr) || if ( (unsigned int)len < sizeof(struct tcphdr) || (unsigned int)caplen < sizeof(struct tcphdr) )
(unsigned int)caplen < sizeof(struct tcphdr) )
// Packet too short, will be dropped anyway. // Packet too short, will be dropped anyway.
return false; return false;
@ -117,6 +102,6 @@ bool PacketFilter::MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int c
} }
return util::detail::random_number() < f.probability; return util::detail::random_number() < f.probability;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,20 +7,17 @@
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
#include "zeek/PrefixTable.h" #include "zeek/PrefixTable.h"
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
class Val; class Val;
namespace detail namespace detail {
{
class PacketFilter class PacketFilter {
{
public: public:
explicit PacketFilter(bool arg_default); explicit PacketFilter(bool arg_default);
~PacketFilter() { } ~PacketFilter() {}
// Drops all packets from a particular source (which may be given // Drops all packets from a particular source (which may be given
// as an AddrVal or a SubnetVal) which hasn't any of TCP flags set // as an AddrVal or a SubnetVal) which hasn't any of TCP flags set
@ -41,8 +38,7 @@ public:
bool Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen); bool Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen);
private: private:
struct Filter struct Filter {
{
uint32_t tcp_flags; uint32_t tcp_flags;
double probability; double probability;
}; };
@ -54,7 +50,7 @@ private:
bool default_match; bool default_match;
PrefixTable src_filter; PrefixTable src_filter;
PrefixTable dst_filter; PrefixTable dst_filter;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -9,11 +9,9 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
static void pipe_fail(int eno) static void pipe_fail(int eno) {
{
char tmp[256]; char tmp[256];
zeek::util::zeek_strerror_r(eno, tmp, sizeof(tmp)); zeek::util::zeek_strerror_r(eno, tmp, sizeof(tmp));
@ -21,17 +19,15 @@ static void pipe_fail(int eno)
reporter->FatalError("Pipe failure: %s", tmp); reporter->FatalError("Pipe failure: %s", tmp);
else else
fprintf(stderr, "Pipe failure: %s", tmp); fprintf(stderr, "Pipe failure: %s", tmp);
} }
static int set_flags(int fd, int flags) static int set_flags(int fd, int flags) {
{
auto rval = fcntl(fd, F_GETFD); auto rval = fcntl(fd, F_GETFD);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{
rval |= flags; rval |= flags;
if ( fcntl(fd, F_SETFD, rval) == -1 ) if ( fcntl(fd, F_SETFD, rval) == -1 )
@ -39,17 +35,15 @@ static int set_flags(int fd, int flags)
} }
return rval; return rval;
} }
static int unset_flags(int fd, int flags) static int unset_flags(int fd, int flags) {
{
auto rval = fcntl(fd, F_GETFD); auto rval = fcntl(fd, F_GETFD);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{
rval &= ~flags; rval &= ~flags;
if ( fcntl(fd, F_SETFD, rval) == -1 ) if ( fcntl(fd, F_SETFD, rval) == -1 )
@ -57,17 +51,15 @@ static int unset_flags(int fd, int flags)
} }
return rval; return rval;
} }
static int set_status_flags(int fd, int flags) static int set_status_flags(int fd, int flags) {
{
auto rval = fcntl(fd, F_GETFL); auto rval = fcntl(fd, F_GETFL);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{
rval |= flags; rval |= flags;
if ( fcntl(fd, F_SETFL, rval) == -1 ) if ( fcntl(fd, F_SETFL, rval) == -1 )
@ -75,10 +67,9 @@ static int set_status_flags(int fd, int flags)
} }
return rval; return rval;
} }
static int dup_or_fail(int fd, int flags, int status_flags) static int dup_or_fail(int fd, int flags, int status_flags) {
{
int rval = dup(fd); int rval = dup(fd);
if ( rval < 0 ) if ( rval < 0 )
@ -87,17 +78,14 @@ static int dup_or_fail(int fd, int flags, int status_flags)
set_flags(fd, flags); set_flags(fd, flags);
set_status_flags(fd, status_flags); set_status_flags(fd, status_flags);
return rval; return rval;
} }
Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* arg_fds) Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* arg_fds) {
{ if ( arg_fds ) {
if ( arg_fds )
{
fds[0] = arg_fds[0]; fds[0] = arg_fds[0];
fds[1] = arg_fds[1]; fds[1] = arg_fds[1];
} }
else else {
{
// pipe2 can set flags atomically, but not yet available everywhere. // pipe2 can set flags atomically, but not yet available everywhere.
if ( ::pipe(fds) ) if ( ::pipe(fds) )
pipe_fail(errno); pipe_fail(errno);
@ -107,38 +95,33 @@ Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* ar
flags[1] = set_flags(fds[1], flags1); flags[1] = set_flags(fds[1], flags1);
status_flags[0] = set_status_flags(fds[0], status_flags0); status_flags[0] = set_status_flags(fds[0], status_flags0);
status_flags[1] = set_status_flags(fds[1], status_flags1); status_flags[1] = set_status_flags(fds[1], status_flags1);
} }
void Pipe::SetFlags(int arg_flags) void Pipe::SetFlags(int arg_flags) {
{
flags[0] = set_flags(fds[0], arg_flags); flags[0] = set_flags(fds[0], arg_flags);
flags[1] = set_flags(fds[1], arg_flags); flags[1] = set_flags(fds[1], arg_flags);
} }
void Pipe::UnsetFlags(int arg_flags) void Pipe::UnsetFlags(int arg_flags) {
{
flags[0] = unset_flags(fds[0], arg_flags); flags[0] = unset_flags(fds[0], arg_flags);
flags[1] = unset_flags(fds[1], arg_flags); flags[1] = unset_flags(fds[1], arg_flags);
} }
Pipe::~Pipe() Pipe::~Pipe() {
{
close(fds[0]); close(fds[0]);
close(fds[1]); close(fds[1]);
} }
Pipe::Pipe(const Pipe& other) Pipe::Pipe(const Pipe& other) {
{
fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]); fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]);
fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]); fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]);
flags[0] = other.flags[0]; flags[0] = other.flags[0];
flags[1] = other.flags[1]; flags[1] = other.flags[1];
status_flags[0] = other.status_flags[0]; status_flags[0] = other.status_flags[0];
status_flags[1] = other.status_flags[1]; status_flags[1] = other.status_flags[1];
} }
Pipe& Pipe::operator=(const Pipe& other) Pipe& Pipe::operator=(const Pipe& other) {
{
if ( this == &other ) if ( this == &other )
return *this; return *this;
@ -151,12 +134,10 @@ Pipe& Pipe::operator=(const Pipe& other)
status_flags[0] = other.status_flags[0]; status_flags[0] = other.status_flags[0];
status_flags[1] = other.status_flags[1]; status_flags[1] = other.status_flags[1];
return *this; return *this;
} }
PipePair::PipePair(int flags, int status_flags, int* fds) PipePair::PipePair(int flags, int status_flags, int* fds)
: pipes{Pipe(flags, flags, status_flags, status_flags, fds ? fds + 0 : nullptr), : pipes{Pipe(flags, flags, status_flags, status_flags, fds ? fds + 0 : nullptr),
Pipe(flags, flags, status_flags, status_flags, fds ? fds + 2 : nullptr)} Pipe(flags, flags, status_flags, status_flags, fds ? fds + 2 : nullptr)} {}
{
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -2,11 +2,9 @@
#pragma once #pragma once
namespace zeek::detail namespace zeek::detail {
{
class Pipe class Pipe {
{
public: public:
/** /**
* Create a pair of file descriptors via pipe(), or aborts if it cannot. * Create a pair of file descriptors via pipe(), or aborts if it cannot.
@ -18,8 +16,7 @@ public:
* than create ones from a new pipe. Should point to memory containing * than create ones from a new pipe. Should point to memory containing
* two consecutive file descriptors, the "read" one and then the "write" one. * two consecutive file descriptors, the "read" one and then the "write" one.
*/ */
explicit Pipe(int flags0 = 0, int flags1 = 0, int status_flags0 = 0, int status_flags1 = 0, explicit Pipe(int flags0 = 0, int flags1 = 0, int status_flags0 = 0, int status_flags1 = 0, int* fds = nullptr);
int* fds = nullptr);
/** /**
* Close the pair of file descriptors owned by the object. * Close the pair of file descriptors owned by the object.
@ -63,13 +60,12 @@ private:
int fds[2]; int fds[2];
int flags[2]; int flags[2];
int status_flags[2]; int status_flags[2];
}; };
/** /**
* A pair of pipes that can be used for bi-directional IPC. * A pair of pipes that can be used for bi-directional IPC.
*/ */
class PipePair class PipePair {
{
public: public:
/** /**
* Create a pair of pipes * Create a pair of pipes
@ -125,6 +121,6 @@ public:
private: private:
Pipe pipes[2]; Pipe pipes[2];
bool swapped = false; bool swapped = false;
}; };
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -17,15 +17,12 @@
using namespace std; using namespace std;
struct PolicyFile struct PolicyFile {
{ PolicyFile() {
PolicyFile()
{
filedata = nullptr; filedata = nullptr;
lmtime = 0; lmtime = 0;
} }
~PolicyFile() ~PolicyFile() {
{
delete[] filedata; delete[] filedata;
filedata = nullptr; filedata = nullptr;
} }
@ -33,22 +30,19 @@ struct PolicyFile
time_t lmtime; time_t lmtime;
char* filedata; char* filedata;
vector<const char*> lines; vector<const char*> lines;
}; };
using PolicyFileMap = map<string, PolicyFile*>; using PolicyFileMap = map<string, PolicyFile*>;
static PolicyFileMap policy_files; static PolicyFileMap policy_files;
namespace zeek::detail namespace zeek::detail {
{
int how_many_lines_in(const char* policy_filename) int how_many_lines_in(const char* policy_filename) {
{
if ( ! policy_filename ) if ( ! policy_filename )
reporter->InternalError("NULL value passed to how_many_lines_in\n"); reporter->InternalError("NULL value passed to how_many_lines_in\n");
FILE* throwaway = fopen(policy_filename, "r"); FILE* throwaway = fopen(policy_filename, "r");
if ( ! throwaway ) if ( ! throwaway ) {
{
debug_msg("Could not open policy file: %s.\n", policy_filename); debug_msg("Could not open policy file: %s.\n", policy_filename);
return -1; return -1;
} }
@ -58,11 +52,9 @@ int how_many_lines_in(const char* policy_filename)
PolicyFileMap::iterator match; PolicyFileMap::iterator match;
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
debug_msg("Policy file %s was not loaded.\n", policy_filename); debug_msg("Policy file %s was not loaded.\n", policy_filename);
return -1; return -1;
} }
@ -70,11 +62,9 @@ int how_many_lines_in(const char* policy_filename)
PolicyFile* pf = match->second; PolicyFile* pf = match->second;
return pf->lines.size(); return pf->lines.size();
} }
bool LoadPolicyFileText(const char* policy_filename, bool LoadPolicyFileText(const char* policy_filename, const std::optional<std::string>& preloaded_content) {
const std::optional<std::string>& preloaded_content)
{
if ( ! policy_filename ) if ( ! policy_filename )
return true; return true;
@ -84,26 +74,22 @@ bool LoadPolicyFileText(const char* policy_filename,
PolicyFile* pf = new PolicyFile; PolicyFile* pf = new PolicyFile;
policy_files.insert(PolicyFileMap::value_type(policy_filename, pf)); policy_files.insert(PolicyFileMap::value_type(policy_filename, pf));
if ( preloaded_content ) if ( preloaded_content ) {
{
auto size = preloaded_content->size(); auto size = preloaded_content->size();
pf->filedata = new char[size + 1]; pf->filedata = new char[size + 1];
memcpy(pf->filedata, preloaded_content->data(), size); memcpy(pf->filedata, preloaded_content->data(), size);
pf->filedata[size] = '\0'; pf->filedata[size] = '\0';
} }
else else {
{
FILE* f = fopen(policy_filename, "r"); FILE* f = fopen(policy_filename, "r");
if ( ! f ) if ( ! f ) {
{
debug_msg("Could not open policy file: %s.\n", policy_filename); debug_msg("Could not open policy file: %s.\n", policy_filename);
return false; return false;
} }
struct stat st; struct stat st;
if ( fstat(fileno(f), &st) != 0 ) if ( fstat(fileno(f), &st) != 0 ) {
{
char buf[256]; char buf[256];
util::zeek_strerror_r(errno, buf, sizeof(buf)); util::zeek_strerror_r(errno, buf, sizeof(buf));
reporter->Error("fstat failed on %s: %s", policy_filename, buf); reporter->Error("fstat failed on %s: %s", policy_filename, buf);
@ -127,10 +113,8 @@ bool LoadPolicyFileText(const char* policy_filename,
// Separate the string by newlines. // Separate the string by newlines.
pf->lines.push_back(pf->filedata); pf->lines.push_back(pf->filedata);
for ( char* iter = pf->filedata; *iter; ++iter ) for ( char* iter = pf->filedata; *iter; ++iter ) {
{ if ( *iter == '\n' ) {
if ( *iter == '\n' )
{
*iter = 0; *iter = 0;
if ( *(iter + 1) ) if ( *(iter + 1) )
pf->lines.push_back(iter + 1); pf->lines.push_back(iter + 1);
@ -141,18 +125,15 @@ bool LoadPolicyFileText(const char* policy_filename,
assert(pf->lines[i][0] != '\n'); assert(pf->lines[i][0] != '\n');
return true; return true;
} }
// REMEMBER: line number arguments are indexed from 0. // REMEMBER: line number arguments are indexed from 0.
bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool show_numbers) {
bool show_numbers)
{
if ( ! policy_filename ) if ( ! policy_filename )
return true; return true;
FILE* throwaway = fopen(policy_filename, "r"); FILE* throwaway = fopen(policy_filename, "r");
if ( ! throwaway ) if ( ! throwaway ) {
{
debug_msg("Could not open policy file: %s.\n", policy_filename); debug_msg("Could not open policy file: %s.\n", policy_filename);
return false; return false;
} }
@ -162,11 +143,9 @@ bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned i
PolicyFileMap::iterator match; PolicyFileMap::iterator match;
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
debug_msg("Policy file %s was not loaded.\n", policy_filename); debug_msg("Policy file %s was not loaded.\n", policy_filename);
return false; return false;
} }
@ -177,18 +156,15 @@ bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned i
if ( start_line < 1 ) if ( start_line < 1 )
start_line = 1; start_line = 1;
if ( start_line > pf->lines.size() ) if ( start_line > pf->lines.size() ) {
{ debug_msg("Line number %d out of range; %s has %d lines\n", start_line, policy_filename, int(pf->lines.size()));
debug_msg("Line number %d out of range; %s has %d lines\n", start_line, policy_filename,
int(pf->lines.size()));
return false; return false;
} }
if ( start_line + how_many_lines - 1 > pf->lines.size() ) if ( start_line + how_many_lines - 1 > pf->lines.size() )
how_many_lines = pf->lines.size() - start_line + 1; how_many_lines = pf->lines.size() - start_line + 1;
for ( unsigned int i = 0; i < how_many_lines; ++i ) for ( unsigned int i = 0; i < how_many_lines; ++i ) {
{
if ( show_numbers ) if ( show_numbers )
debug_msg("%d\t", i + start_line); debug_msg("%d\t", i + start_line);
@ -197,6 +173,6 @@ bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned i
} }
return true; return true;
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -17,16 +17,13 @@
#include <optional> #include <optional>
#include <string> #include <string>
namespace zeek::detail namespace zeek::detail {
{
int how_many_lines_in(const char* policy_filename); int how_many_lines_in(const char* policy_filename);
bool LoadPolicyFileText(const char* policy_filename, bool LoadPolicyFileText(const char* policy_filename, const std::optional<std::string>& preloaded_content = {});
const std::optional<std::string>& preloaded_content = {});
// start_line is 1-based (the intuitive way) // start_line is 1-based (the intuitive way)
bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool show_numbers);
bool show_numbers);
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -3,11 +3,9 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width) prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width) {
{
prefix_t* prefix = (prefix_t*)util::safe_malloc(sizeof(prefix_t)); prefix_t* prefix = (prefix_t*)util::safe_malloc(sizeof(prefix_t));
addr.CopyIPv6(&prefix->add.sin6); addr.CopyIPv6(&prefix->add.sin6);
@ -16,23 +14,19 @@ prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width)
prefix->ref_count = 1; prefix->ref_count = 1;
return prefix; return prefix;
} }
IPPrefix PrefixTable::PrefixToIPPrefix(prefix_t* prefix) IPPrefix PrefixTable::PrefixToIPPrefix(prefix_t* prefix) {
{ return IPPrefix(IPAddr(IPv6, reinterpret_cast<const uint32_t*>(&prefix->add.sin6), IPAddr::Network), prefix->bitlen,
return IPPrefix( true);
IPAddr(IPv6, reinterpret_cast<const uint32_t*>(&prefix->add.sin6), IPAddr::Network), }
prefix->bitlen, true);
}
void* PrefixTable::Insert(const IPAddr& addr, int width, void* data) void* PrefixTable::Insert(const IPAddr& addr, int width, void* data) {
{
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
patricia_node_t* node = patricia_lookup(tree, prefix); patricia_node_t* node = patricia_lookup(tree, prefix);
Deref_Prefix(prefix); Deref_Prefix(prefix);
if ( ! node ) if ( ! node ) {
{
reporter->InternalWarning("Cannot create node in patricia tree"); reporter->InternalWarning("Cannot create node in patricia tree");
return nullptr; return nullptr;
} }
@ -44,32 +38,23 @@ void* PrefixTable::Insert(const IPAddr& addr, int width, void* data)
node->data = data ? data : node; node->data = data ? data : node;
return old; return old;
} }
void* PrefixTable::Insert(const Val* value, void* data) void* PrefixTable::Insert(const Val* value, void* data) {
{
// [elem] -> elem // [elem] -> elem
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
value = value->AsListVal()->Idx(0).get(); value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Insert(value->AsAddr(), 128, data); break;
case TYPE_ADDR:
return Insert(value->AsAddr(), 128, data);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Insert(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), data); break;
return Insert(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), data);
break;
default: default: reporter->InternalWarning("Wrong index type for PrefixTable"); return nullptr;
reporter->InternalWarning("Wrong index type for PrefixTable");
return nullptr;
}
} }
}
std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr, int width) const std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr, int width) const {
{
std::list<std::tuple<IPPrefix, void*>> out; std::list<std::tuple<IPPrefix, void*>> out;
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
@ -84,51 +69,40 @@ std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr,
Deref_Prefix(prefix); Deref_Prefix(prefix);
free(list); free(list);
return out; return out;
} }
std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const SubNetVal* value) const std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const SubNetVal* value) const {
{
return FindAll(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6()); return FindAll(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6());
} }
void* PrefixTable::Lookup(const IPAddr& addr, int width, bool exact) const void* PrefixTable::Lookup(const IPAddr& addr, int width, bool exact) const {
{
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
patricia_node_t* node = exact ? patricia_search_exact(tree, prefix) patricia_node_t* node = exact ? patricia_search_exact(tree, prefix) : patricia_search_best(tree, prefix);
: patricia_search_best(tree, prefix);
int elems = 0; int elems = 0;
patricia_node_t** list = nullptr; patricia_node_t** list = nullptr;
Deref_Prefix(prefix); Deref_Prefix(prefix);
return node ? node->data : nullptr; return node ? node->data : nullptr;
} }
void* PrefixTable::Lookup(const Val* value, bool exact) const void* PrefixTable::Lookup(const Val* value, bool exact) const {
{
// [elem] -> elem // [elem] -> elem
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
value = value->AsListVal()->Idx(0).get(); value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Lookup(value->AsAddr(), 128, exact); break;
case TYPE_ADDR:
return Lookup(value->AsAddr(), 128, exact);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Lookup(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), exact); break;
return Lookup(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), exact);
break;
default: default:
reporter->InternalWarning("Wrong index type %d for PrefixTable", reporter->InternalWarning("Wrong index type %d for PrefixTable", value->GetType()->Tag());
value->GetType()->Tag());
return nullptr; return nullptr;
} }
} }
void* PrefixTable::Remove(const IPAddr& addr, int width) void* PrefixTable::Remove(const IPAddr& addr, int width) {
{
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
patricia_node_t* node = patricia_search_exact(tree, prefix); patricia_node_t* node = patricia_search_exact(tree, prefix);
Deref_Prefix(prefix); Deref_Prefix(prefix);
@ -140,49 +114,37 @@ void* PrefixTable::Remove(const IPAddr& addr, int width)
patricia_remove(tree, node); patricia_remove(tree, node);
return old; return old;
} }
void* PrefixTable::Remove(const Val* value) void* PrefixTable::Remove(const Val* value) {
{
// [elem] -> elem // [elem] -> elem
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
value = value->AsListVal()->Idx(0).get(); value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Remove(value->AsAddr(), 128); break;
case TYPE_ADDR:
return Remove(value->AsAddr(), 128);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Remove(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6()); break;
return Remove(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6());
break;
default: default: reporter->InternalWarning("Wrong index type for PrefixTable"); return nullptr;
reporter->InternalWarning("Wrong index type for PrefixTable");
return nullptr;
}
} }
}
PrefixTable::iterator PrefixTable::InitIterator() PrefixTable::iterator PrefixTable::InitIterator() {
{
iterator i; iterator i;
i.Xsp = i.Xstack; i.Xsp = i.Xstack;
i.Xrn = tree->head; i.Xrn = tree->head;
i.Xnode = nullptr; i.Xnode = nullptr;
return i; return i;
} }
void* PrefixTable::GetNext(iterator* i) void* PrefixTable::GetNext(iterator* i) {
{ while ( true ) {
while ( true )
{
i->Xnode = i->Xrn; i->Xnode = i->Xrn;
if ( ! i->Xnode ) if ( ! i->Xnode )
return nullptr; return nullptr;
if ( i->Xrn->l ) if ( i->Xrn->l ) {
{
if ( i->Xrn->r ) if ( i->Xrn->r )
*i->Xsp++ = i->Xrn->r; *i->Xsp++ = i->Xrn->r;
@ -203,6 +165,6 @@ void* PrefixTable::GetNext(iterator* i)
} }
// Not reached. // Not reached.
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -1,29 +1,24 @@
#pragma once #pragma once
extern "C" extern "C" {
{
#include "zeek/3rdparty/patricia.h" #include "zeek/3rdparty/patricia.h"
} }
#include <list> #include <list>
#include <tuple> #include <tuple>
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
namespace zeek namespace zeek {
{
class Val; class Val;
class SubNetVal; class SubNetVal;
namespace detail namespace detail {
{
class PrefixTable class PrefixTable {
{
private: private:
struct iterator struct iterator {
{
patricia_node_t* Xstack[PATRICIA_MAXBITS + 1]; patricia_node_t* Xstack[PATRICIA_MAXBITS + 1];
patricia_node_t** Xsp; patricia_node_t** Xsp;
patricia_node_t* Xrn; patricia_node_t* Xrn;
@ -31,8 +26,7 @@ private:
}; };
public: public:
PrefixTable() PrefixTable() {
{
tree = New_Patricia(128); tree = New_Patricia(128);
delete_function = nullptr; delete_function = nullptr;
} }
@ -74,7 +68,7 @@ private:
patricia_tree_t* tree; patricia_tree_t* tree;
data_fn_t delete_function; data_fn_t delete_function;
}; };
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -10,24 +10,18 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
PriorityQueue::PriorityQueue(int initial_size) : max_heap_size(initial_size) PriorityQueue::PriorityQueue(int initial_size) : max_heap_size(initial_size) { heap = new PQ_Element*[max_heap_size]; }
{
heap = new PQ_Element*[max_heap_size];
}
PriorityQueue::~PriorityQueue() PriorityQueue::~PriorityQueue() {
{
for ( int i = 0; i < heap_size; ++i ) for ( int i = 0; i < heap_size; ++i )
delete heap[i]; delete heap[i];
delete[] heap; delete[] heap;
} }
PQ_Element* PriorityQueue::Remove() PQ_Element* PriorityQueue::Remove() {
{
if ( heap_size == 0 ) if ( heap_size == 0 )
return nullptr; return nullptr;
@ -39,10 +33,9 @@ PQ_Element* PriorityQueue::Remove()
top->SetOffset(-1); // = not in heap top->SetOffset(-1); // = not in heap
return top; return top;
} }
PQ_Element* PriorityQueue::Remove(PQ_Element* e) PQ_Element* PriorityQueue::Remove(PQ_Element* e) {
{
if ( e->Offset() < 0 || e->Offset() >= heap_size || heap[e->Offset()] != e ) if ( e->Offset() < 0 || e->Offset() >= heap_size || heap[e->Offset()] != e )
return nullptr; // not in heap return nullptr; // not in heap
@ -55,10 +48,9 @@ PQ_Element* PriorityQueue::Remove(PQ_Element* e)
reporter->InternalError("inconsistency in PriorityQueue::Remove"); reporter->InternalError("inconsistency in PriorityQueue::Remove");
return e2; return e2;
} }
bool PriorityQueue::Add(PQ_Element* e) bool PriorityQueue::Add(PQ_Element* e) {
{
SetElement(heap_size, e); SetElement(heap_size, e);
BubbleUp(heap_size); BubbleUp(heap_size);
@ -72,10 +64,9 @@ bool PriorityQueue::Add(PQ_Element* e)
return Resize(max_heap_size * 2); return Resize(max_heap_size * 2);
else else
return true; return true;
} }
bool PriorityQueue::Resize(int new_size) bool PriorityQueue::Resize(int new_size) {
{
PQ_Element** tmp = new PQ_Element*[new_size]; PQ_Element** tmp = new PQ_Element*[new_size];
for ( int i = 0; i < max_heap_size; ++i ) for ( int i = 0; i < max_heap_size; ++i )
tmp[i] = heap[i]; tmp[i] = heap[i];
@ -86,23 +77,20 @@ bool PriorityQueue::Resize(int new_size)
max_heap_size = new_size; max_heap_size = new_size;
return heap != nullptr; return heap != nullptr;
} }
void PriorityQueue::BubbleUp(int bin) void PriorityQueue::BubbleUp(int bin) {
{
if ( bin == 0 ) if ( bin == 0 )
return; return;
int p = Parent(bin); int p = Parent(bin);
if ( heap[p]->Time() > heap[bin]->Time() ) if ( heap[p]->Time() > heap[bin]->Time() ) {
{
Swap(p, bin); Swap(p, bin);
BubbleUp(p); BubbleUp(p);
} }
} }
void PriorityQueue::BubbleDown(int bin) void PriorityQueue::BubbleDown(int bin) {
{
double v = heap[bin]->Time(); double v = heap[bin]->Time();
int l = LeftChild(bin); int l = LeftChild(bin);
@ -111,32 +99,27 @@ void PriorityQueue::BubbleDown(int bin)
if ( l >= heap_size ) if ( l >= heap_size )
return; // No children. return; // No children.
if ( r >= heap_size ) if ( r >= heap_size ) { // Just a left child.
{ // Just a left child.
if ( heap[l]->Time() < v ) if ( heap[l]->Time() < v )
Swap(l, bin); Swap(l, bin);
} }
else else {
{
double lv = heap[l]->Time(); double lv = heap[l]->Time();
double rv = heap[r]->Time(); double rv = heap[r]->Time();
if ( lv < rv ) if ( lv < rv ) {
{ if ( lv < v ) {
if ( lv < v )
{
Swap(l, bin); Swap(l, bin);
BubbleDown(l); BubbleDown(l);
} }
} }
else if ( rv < v ) else if ( rv < v ) {
{
Swap(r, bin); Swap(r, bin);
BubbleDown(r); BubbleDown(r);
} }
} }
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -7,15 +7,13 @@
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
namespace zeek::detail namespace zeek::detail {
{
class PriorityQueue; class PriorityQueue;
class PQ_Element class PQ_Element {
{
public: public:
explicit PQ_Element(double t) : time(t) { } explicit PQ_Element(double t) : time(t) {}
virtual ~PQ_Element() = default; virtual ~PQ_Element() = default;
double Time() const { return time; } double Time() const { return time; }
@ -29,17 +27,15 @@ protected:
PQ_Element() = default; PQ_Element() = default;
double time = 0.0; double time = 0.0;
int offset = -1; int offset = -1;
}; };
class PriorityQueue class PriorityQueue {
{
public: public:
explicit PriorityQueue(int initial_size = 16); explicit PriorityQueue(int initial_size = 16);
~PriorityQueue(); ~PriorityQueue();
// Returns the top of queue, or nil if the queue is empty. // Returns the top of queue, or nil if the queue is empty.
PQ_Element* Top() const PQ_Element* Top() const {
{
if ( heap_size == 0 ) if ( heap_size == 0 )
return nullptr; return nullptr;
@ -74,14 +70,12 @@ protected:
int RightChild(int bin) const { return LeftChild(bin) + 1; } int RightChild(int bin) const { return LeftChild(bin) + 1; }
void SetElement(int bin, PQ_Element* e) void SetElement(int bin, PQ_Element* e) {
{
heap[bin] = e; heap[bin] = e;
e->SetOffset(bin); e->SetOffset(bin);
} }
void Swap(int bin1, int bin2) void Swap(int bin1, int bin2) {
{
PQ_Element* t = heap[bin1]; PQ_Element* t = heap[bin1];
SetElement(bin1, heap[bin2]); SetElement(bin1, heap[bin2]);
SetElement(bin2, t); SetElement(bin2, t);
@ -92,6 +86,6 @@ protected:
int peak_heap_size = 0; int peak_heap_size = 0;
int max_heap_size = 0; int max_heap_size = 0;
uint64_t cumulative_num = 0; uint64_t cumulative_num = 0;
}; };
} // namespace zeek::detail } // namespace zeek::detail

256
src/RE.cc
View file

@ -24,36 +24,29 @@ extern int RE_parse(void);
extern void RE_set_input(const char* str); extern void RE_set_input(const char* str);
extern void RE_done_with_scan(); extern void RE_done_with_scan();
namespace zeek namespace zeek {
{ namespace detail {
namespace detail
{
Specific_RE_Matcher::Specific_RE_Matcher(match_type arg_mt, bool arg_multiline) Specific_RE_Matcher::Specific_RE_Matcher(match_type arg_mt, bool arg_multiline)
: mt(arg_mt), multiline(arg_multiline), equiv_class(NUM_SYM) : mt(arg_mt), multiline(arg_multiline), equiv_class(NUM_SYM) {
{
any_ccl = nullptr; any_ccl = nullptr;
single_line_ccl = nullptr; single_line_ccl = nullptr;
dfa = nullptr; dfa = nullptr;
ecs = nullptr; ecs = nullptr;
accepted = new AcceptingSet(); accepted = new AcceptingSet();
} }
Specific_RE_Matcher::~Specific_RE_Matcher() Specific_RE_Matcher::~Specific_RE_Matcher() {
{
for ( int i = 0; i < ccl_list.length(); ++i ) for ( int i = 0; i < ccl_list.length(); ++i )
delete ccl_list[i]; delete ccl_list[i];
Unref(dfa); Unref(dfa);
delete accepted; delete accepted;
} }
CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode) CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode) {
{ if ( single_line_mode ) {
if ( single_line_mode ) if ( ! single_line_ccl ) {
{
if ( ! single_line_ccl )
{
single_line_ccl = new CCL(); single_line_ccl = new CCL();
single_line_ccl->Negate(); single_line_ccl->Negate();
EC()->CCL_Use(single_line_ccl); EC()->CCL_Use(single_line_ccl);
@ -62,8 +55,7 @@ CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode)
return single_line_ccl; return single_line_ccl;
} }
if ( ! any_ccl ) if ( ! any_ccl ) {
{
any_ccl = new CCL(); any_ccl = new CCL();
if ( ! multiline ) if ( ! multiline )
any_ccl->Add('\n'); any_ccl->Add('\n');
@ -72,54 +64,44 @@ CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode)
} }
return any_ccl; return any_ccl;
} }
void Specific_RE_Matcher::ConvertCCLs() void Specific_RE_Matcher::ConvertCCLs() {
{
for ( int i = 0; i < ccl_list.length(); ++i ) for ( int i = 0; i < ccl_list.length(); ++i )
equiv_class.ConvertCCL(ccl_list[i]); equiv_class.ConvertCCL(ccl_list[i]);
} }
void Specific_RE_Matcher::AddPat(const char* new_pat) void Specific_RE_Matcher::AddPat(const char* new_pat) {
{
if ( mt == MATCH_EXACTLY ) if ( mt == MATCH_EXACTLY )
AddExactPat(new_pat); AddExactPat(new_pat);
else else
AddAnywherePat(new_pat); AddAnywherePat(new_pat);
} }
void Specific_RE_Matcher::AddAnywherePat(const char* new_pat) void Specific_RE_Matcher::AddAnywherePat(const char* new_pat) {
{
AddPat(new_pat, "^?(.|\\n)*(%s)", "(%s)|(^?(.|\\n)*(%s))"); AddPat(new_pat, "^?(.|\\n)*(%s)", "(%s)|(^?(.|\\n)*(%s))");
} }
void Specific_RE_Matcher::AddExactPat(const char* new_pat) void Specific_RE_Matcher::AddExactPat(const char* new_pat) { AddPat(new_pat, "^?(%s)$?", "(%s)|(^?(%s)$?)"); }
{
AddPat(new_pat, "^?(%s)$?", "(%s)|(^?(%s)$?)");
}
void Specific_RE_Matcher::AddPat(const char* new_pat, const char* orig_fmt, const char* app_fmt) void Specific_RE_Matcher::AddPat(const char* new_pat, const char* orig_fmt, const char* app_fmt) {
{
if ( ! pattern_text.empty() ) if ( ! pattern_text.empty() )
pattern_text = util::fmt(app_fmt, pattern_text.c_str(), new_pat); pattern_text = util::fmt(app_fmt, pattern_text.c_str(), new_pat);
else else
pattern_text = util::fmt(orig_fmt, new_pat); pattern_text = util::fmt(orig_fmt, new_pat);
} }
void Specific_RE_Matcher::MakeCaseInsensitive() void Specific_RE_Matcher::MakeCaseInsensitive() {
{
const char fmt[] = "(?i:%s)"; const char fmt[] = "(?i:%s)";
pattern_text = util::fmt(fmt, pattern_text.c_str()); pattern_text = util::fmt(fmt, pattern_text.c_str());
} }
void Specific_RE_Matcher::MakeSingleLine() void Specific_RE_Matcher::MakeSingleLine() {
{
const char fmt[] = "(?s:%s)"; const char fmt[] = "(?s:%s)";
pattern_text = util::fmt(fmt, pattern_text.c_str()); pattern_text = util::fmt(fmt, pattern_text.c_str());
} }
bool Specific_RE_Matcher::Compile(bool lazy) bool Specific_RE_Matcher::Compile(bool lazy) {
{
if ( pattern_text.empty() ) if ( pattern_text.empty() )
return false; return false;
@ -129,8 +111,7 @@ bool Specific_RE_Matcher::Compile(bool lazy)
int parse_status = RE_parse(); int parse_status = RE_parse();
RE_done_with_scan(); RE_done_with_scan();
if ( parse_status ) if ( parse_status ) {
{
reporter->Error("error compiling pattern /%s/", pattern_text.c_str()); reporter->Error("error compiling pattern /%s/", pattern_text.c_str());
Unref(nfa); Unref(nfa);
nfa = nullptr; nfa = nullptr;
@ -148,10 +129,9 @@ bool Specific_RE_Matcher::Compile(bool lazy)
ecs = EC()->EquivClasses(); ecs = EC()->EquivClasses();
return true; return true;
} }
bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx) bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx) {
{
if ( (size_t)set.length() != idx.size() ) if ( (size_t)set.length() != idx.size() )
reporter->InternalError("compileset: lengths of sets differ"); reporter->InternalError("compileset: lengths of sets differ");
@ -159,14 +139,12 @@ bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx
NFA_Machine* set_nfa = nullptr; NFA_Machine* set_nfa = nullptr;
loop_over_list(set, i) loop_over_list(set, i) {
{
RE_set_input(set[i]); RE_set_input(set[i]);
int parse_status = RE_parse(); int parse_status = RE_parse();
RE_done_with_scan(); RE_done_with_scan();
if ( parse_status ) if ( parse_status ) {
{
reporter->Error("error compiling pattern /%s/", set[i]); reporter->Error("error compiling pattern /%s/", set[i]);
if ( set_nfa && set_nfa != nfa ) if ( set_nfa && set_nfa != nfa )
@ -195,50 +173,32 @@ bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx
ecs = EC()->EquivClasses(); ecs = EC()->EquivClasses();
return true; return true;
} }
std::string Specific_RE_Matcher::LookupDef(const std::string& def) std::string Specific_RE_Matcher::LookupDef(const std::string& def) {
{
const auto& iter = defs.find(def); const auto& iter = defs.find(def);
if ( iter != defs.end() ) if ( iter != defs.end() )
return iter->second; return iter->second;
return {}; return {};
} }
bool Specific_RE_Matcher::MatchAll(const char* s) bool Specific_RE_Matcher::MatchAll(const char* s) { return MatchAll((const u_char*)(s), strlen(s)); }
{
return MatchAll((const u_char*)(s), strlen(s));
}
bool Specific_RE_Matcher::MatchAll(const String* s) bool Specific_RE_Matcher::MatchAll(const String* s) {
{
// s->Len() does not include '\0'. // s->Len() does not include '\0'.
return MatchAll(s->Bytes(), s->Len()); return MatchAll(s->Bytes(), s->Len());
} }
int Specific_RE_Matcher::Match(const char* s) int Specific_RE_Matcher::Match(const char* s) { return Match((const u_char*)(s), strlen(s)); }
{
return Match((const u_char*)(s), strlen(s));
}
int Specific_RE_Matcher::Match(const String* s) int Specific_RE_Matcher::Match(const String* s) { return Match(s->Bytes(), s->Len()); }
{
return Match(s->Bytes(), s->Len());
}
int Specific_RE_Matcher::LongestMatch(const char* s) int Specific_RE_Matcher::LongestMatch(const char* s) { return LongestMatch((const u_char*)(s), strlen(s)); }
{
return LongestMatch((const u_char*)(s), strlen(s));
}
int Specific_RE_Matcher::LongestMatch(const String* s) int Specific_RE_Matcher::LongestMatch(const String* s) { return LongestMatch(s->Bytes(), s->Len()); }
{
return LongestMatch(s->Bytes(), s->Len());
}
bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n) bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n) {
{
if ( ! dfa ) if ( ! dfa )
// An empty pattern matches "all" iff what's being // An empty pattern matches "all" iff what's being
// matched is empty. // matched is empty.
@ -247,8 +207,7 @@ bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n)
DFA_State* d = dfa->StartState(); DFA_State* d = dfa->StartState();
d = d->Xtion(ecs[SYM_BOL], dfa); d = d->Xtion(ecs[SYM_BOL], dfa);
while ( d ) while ( d ) {
{
if ( --n < 0 ) if ( --n < 0 )
break; break;
@ -260,10 +219,9 @@ bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n)
d = d->Xtion(ecs[SYM_EOL], dfa); d = d->Xtion(ecs[SYM_EOL], dfa);
return d && d->Accept() != nullptr; return d && d->Accept() != nullptr;
} }
int Specific_RE_Matcher::Match(const u_char* bv, int n) int Specific_RE_Matcher::Match(const u_char* bv, int n) {
{
if ( ! dfa ) if ( ! dfa )
// An empty pattern matches anything. // An empty pattern matches anything.
return 1; return 1;
@ -274,8 +232,7 @@ int Specific_RE_Matcher::Match(const u_char* bv, int n)
if ( ! d ) if ( ! d )
return 0; return 0;
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
int ec = ecs[bv[i]]; int ec = ecs[bv[i]];
d = d->Xtion(ec, dfa); d = d->Xtion(ec, dfa);
if ( ! d ) if ( ! d )
@ -285,33 +242,26 @@ int Specific_RE_Matcher::Match(const u_char* bv, int n)
return i + 1; return i + 1;
} }
if ( d ) if ( d ) {
{
d = d->Xtion(ecs[SYM_EOL], dfa); d = d->Xtion(ecs[SYM_EOL], dfa);
if ( d && d->Accept() ) if ( d && d->Accept() )
return n > 0 ? n : 1; // we can't return 0 here for match... return n > 0 ? n : 1; // we can't return 0 here for match...
} }
return 0; return 0;
} }
void Specific_RE_Matcher::Dump(FILE* f) void Specific_RE_Matcher::Dump(FILE* f) { dfa->Dump(f); }
{
dfa->Dump(f);
}
inline void RE_Match_State::AddMatches(const AcceptingSet& as, MatchPos position) inline void RE_Match_State::AddMatches(const AcceptingSet& as, MatchPos position) {
{
using am_idx = std::pair<AcceptIdx, MatchPos>; using am_idx = std::pair<AcceptIdx, MatchPos>;
for ( AcceptingSet::const_iterator it = as.begin(); it != as.end(); ++it ) for ( AcceptingSet::const_iterator it = as.begin(); it != as.end(); ++it )
accepted_matches.insert(am_idx(*it, position)); accepted_matches.insert(am_idx(*it, position));
} }
bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool clear) bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool clear) {
{ if ( current_pos == -1 ) {
if ( current_pos == -1 )
{
// First call to Match(). // First call to Match().
if ( ! dfa ) if ( ! dfa )
return false; return false;
@ -340,8 +290,7 @@ bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool cle
int m = bol ? n + 1 : n; int m = bol ? n + 1 : n;
int e = eol ? -1 : 0; int e = eol ? -1 : 0;
while ( --m >= e ) while ( --m >= e ) {
{
if ( m == n ) if ( m == n )
ec = ecs[SYM_BOL]; ec = ecs[SYM_BOL];
else if ( m == -1 ) else if ( m == -1 )
@ -351,8 +300,7 @@ bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool cle
DFA_State* next_state = current_state->Xtion(ec, dfa); DFA_State* next_state = current_state->Xtion(ec, dfa);
if ( ! next_state ) if ( ! next_state ) {
{
current_state = nullptr; current_state = nullptr;
break; break;
} }
@ -368,10 +316,9 @@ bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool cle
} }
return accepted_matches.size() != old_matches; return accepted_matches.size() != old_matches;
} }
int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n) int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n) {
{
if ( ! dfa ) if ( ! dfa )
// An empty pattern matches anything. // An empty pattern matches anything.
return 0; return 0;
@ -387,8 +334,7 @@ int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n)
if ( d->Accept() ) if ( d->Accept() )
last_accept = 0; last_accept = 0;
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
int ec = ecs[bv[i]]; int ec = ecs[bv[i]];
d = d->Xtion(ec, dfa); d = d->Xtion(ec, dfa);
@ -399,18 +345,16 @@ int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n)
last_accept = i + 1; last_accept = i + 1;
} }
if ( d ) if ( d ) {
{
d = d->Xtion(ecs[SYM_EOL], dfa); d = d->Xtion(ecs[SYM_EOL], dfa);
if ( d && d->Accept() ) if ( d && d->Accept() )
return n; return n;
} }
return last_accept; return last_accept;
} }
static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2, const char* merge_op) static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2, const char* merge_op) {
{
const char* text1 = re1->PatternText(); const char* text1 = re1->PatternText();
const char* text2 = re2->PatternText(); const char* text2 = re2->PatternText();
@ -422,80 +366,63 @@ static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2, c
merge->Compile(); merge->Compile();
return merge; return merge;
} }
RE_Matcher* RE_Matcher_conjunction(const RE_Matcher* re1, const RE_Matcher* re2) RE_Matcher* RE_Matcher_conjunction(const RE_Matcher* re1, const RE_Matcher* re2) { return matcher_merge(re1, re2, ""); }
{
return matcher_merge(re1, re2, "");
}
RE_Matcher* RE_Matcher_disjunction(const RE_Matcher* re1, const RE_Matcher* re2) RE_Matcher* RE_Matcher_disjunction(const RE_Matcher* re1, const RE_Matcher* re2) {
{
return matcher_merge(re1, re2, "|"); return matcher_merge(re1, re2, "|");
} }
} // namespace detail } // namespace detail
RE_Matcher::RE_Matcher() RE_Matcher::RE_Matcher() {
{
re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE); re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE);
re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY); re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY);
} }
RE_Matcher::RE_Matcher(const char* pat) : orig_text(pat) RE_Matcher::RE_Matcher(const char* pat) : orig_text(pat) {
{
re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE); re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE);
re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY); re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY);
AddPat(pat); AddPat(pat);
} }
RE_Matcher::RE_Matcher(const char* exact_pat, const char* anywhere_pat) RE_Matcher::RE_Matcher(const char* exact_pat, const char* anywhere_pat) {
{
re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE); re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE);
re_anywhere->SetPat(anywhere_pat); re_anywhere->SetPat(anywhere_pat);
re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY); re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY);
re_exact->SetPat(exact_pat); re_exact->SetPat(exact_pat);
} }
RE_Matcher::~RE_Matcher() RE_Matcher::~RE_Matcher() {
{
delete re_anywhere; delete re_anywhere;
delete re_exact; delete re_exact;
} }
void RE_Matcher::AddPat(const char* new_pat) void RE_Matcher::AddPat(const char* new_pat) {
{
re_anywhere->AddPat(new_pat); re_anywhere->AddPat(new_pat);
re_exact->AddPat(new_pat); re_exact->AddPat(new_pat);
} }
void RE_Matcher::MakeCaseInsensitive() void RE_Matcher::MakeCaseInsensitive() {
{
re_anywhere->MakeCaseInsensitive(); re_anywhere->MakeCaseInsensitive();
re_exact->MakeCaseInsensitive(); re_exact->MakeCaseInsensitive();
is_case_insensitive = true; is_case_insensitive = true;
} }
void RE_Matcher::MakeSingleLine() void RE_Matcher::MakeSingleLine() {
{
re_anywhere->MakeSingleLine(); re_anywhere->MakeSingleLine();
re_exact->MakeSingleLine(); re_exact->MakeSingleLine();
is_single_line = true; is_single_line = true;
} }
bool RE_Matcher::Compile(bool lazy) bool RE_Matcher::Compile(bool lazy) { return re_anywhere->Compile(lazy) && re_exact->Compile(lazy); }
{
return re_anywhere->Compile(lazy) && re_exact->Compile(lazy);
}
TEST_SUITE("re_matcher") TEST_SUITE("re_matcher") {
{ TEST_CASE("simple_pattern") {
TEST_CASE("simple_pattern")
{
RE_Matcher match("[0-9]+"); RE_Matcher match("[0-9]+");
match.Compile(); match.Compile();
CHECK(strcmp(match.OrigText(), "[0-9]+") == 0); CHECK(strcmp(match.OrigText(), "[0-9]+") == 0);
@ -513,8 +440,7 @@ TEST_SUITE("re_matcher")
CHECK(match.MatchAnywhere("abcd") == 0); CHECK(match.MatchAnywhere("abcd") == 0);
} }
TEST_CASE("case_insensitive_mode") TEST_CASE("case_insensitive_mode") {
{
RE_Matcher match("[a-z]+"); RE_Matcher match("[a-z]+");
match.MakeCaseInsensitive(); match.MakeCaseInsensitive();
match.Compile(); match.Compile();
@ -523,8 +449,7 @@ TEST_SUITE("re_matcher")
CHECK(match.MatchExactly("abcDEF")); CHECK(match.MatchExactly("abcDEF"));
} }
TEST_CASE("multi_pattern") TEST_CASE("multi_pattern") {
{
RE_Matcher match("[0-9]+"); RE_Matcher match("[0-9]+");
match.AddPat("[a-z]+"); match.AddPat("[a-z]+");
match.Compile(); match.Compile();
@ -536,8 +461,7 @@ TEST_SUITE("re_matcher")
CHECK_FALSE(match.MatchExactly("abc123")); CHECK_FALSE(match.MatchExactly("abc123"));
} }
TEST_CASE("modes_multi_pattern") TEST_CASE("modes_multi_pattern") {
{
RE_Matcher match("[a-m]+"); RE_Matcher match("[a-m]+");
match.MakeCaseInsensitive(); match.MakeCaseInsensitive();
@ -550,8 +474,7 @@ TEST_SUITE("re_matcher")
CHECK_FALSE(match.MatchExactly("NoP")); CHECK_FALSE(match.MatchExactly("NoP"));
} }
TEST_CASE("single_line_mode") TEST_CASE("single_line_mode") {
{
RE_Matcher match(".*"); RE_Matcher match(".*");
match.MakeSingleLine(); match.MakeSingleLine();
match.Compile(); match.Compile();
@ -580,8 +503,7 @@ TEST_SUITE("re_matcher")
CHECK(match4.MatchExactly("a\nc")); CHECK(match4.MatchExactly("a\nc"));
} }
TEST_CASE("disjunction") TEST_CASE("disjunction") {
{
RE_Matcher match1("a.c"); RE_Matcher match1("a.c");
match1.MakeSingleLine(); match1.MakeSingleLine();
match1.Compile(); match1.Compile();
@ -594,6 +516,6 @@ TEST_SUITE("re_matcher")
CHECK(dj->MatchExactly("def")); CHECK(dj->MatchExactly("def"));
delete dj; delete dj;
} }
} }
} // namespace zeek } // namespace zeek

Some files were not shown because too many files have changed in this diff Show more