This commit is contained in:
Johanna Amann 2023-11-06 09:42:46 +00:00
parent 0afe94154d
commit 36741c2fbf
787 changed files with 131811 additions and 153704 deletions

View file

@ -1,74 +1,66 @@
# Clang-format configuration for Zeek. This configuration requires # Copyright (c) 2020-2023 by the Zeek Project. See LICENSE for details.
# at least clang-format 12.0.1 to format correctly.
---
Language: Cpp Language: Cpp
Standard: c++17
BreakBeforeBraces: Whitesmiths
# BraceWrapping:
# AfterCaseLabel: true
# AfterClass: false
# AfterControlStatement: Always
# AfterEnum: false
# AfterFunction: true
# AfterNamespace: false
# AfterStruct: false
# AfterUnion: false
# AfterExternBlock: false
# BeforeCatch: true
# BeforeElse: true
# BeforeWhile: false
# IndentBraces: true
# SplitEmptyFunction: false
# SplitEmptyRecord: false
# SplitEmptyNamespace: false
AccessModifierOffset: -4 AccessModifierOffset: -4
AlignAfterOpenBracket: Align AlignAfterOpenBracket: Align
AlignTrailingComments: false AlignConsecutiveAssignments: false
AllowShortBlocksOnASingleLine: Empty AlignConsecutiveDeclarations: false
AllowShortEnumsOnASingleLine: true AlignEscapedNewlines: Right
AllowShortFunctionsOnASingleLine: Inline AlignOperands: true
AlignTrailingComments: true
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: true
AllowShortIfStatementsOnASingleLine: false AllowShortIfStatementsOnASingleLine: false
AllowShortLambdasOnASingleLine: Empty
AllowShortLoopsOnASingleLine: false AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterDefinitionReturnType: None
AlwaysBreakAfterReturnType: None AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: Yes
BinPackArguments: true BinPackArguments: true
BinPackParameters: true BinPackParameters: true
BreakConstructorInitializers: BeforeColon BraceWrapping:
AfterClass: false
AfterControlStatement: false
AfterEnum: false
AfterFunction: false
AfterNamespace: false
AfterObjCDeclaration: false
AfterStruct: false
AfterUnion: false
AfterExternBlock: false
BeforeCatch: false
BeforeElse: true
IndentBraces: false
SplitEmptyFunction: false
SplitEmptyRecord: false
SplitEmptyNamespace: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Custom
BreakBeforeInheritanceComma: false
BreakInheritanceList: BeforeColon BreakInheritanceList: BeforeColon
ColumnLimit: 100 BreakBeforeTernaryOperators: false
ConstructorInitializerAllOnOneLineOrOnePerLine: false BreakConstructorInitializersBeforeComma: false
FixNamespaceComments: false BreakConstructorInitializers: BeforeColon
IndentCaseLabels: true BreakAfterJavaFieldAnnotations: false
IndentCaseBlocks: false BreakStringLiterals: true
IndentExternBlock: NoIndent ColumnLimit: 120
IndentPPDirectives: None CommentPragmas: 'NOLINT'
IndentWidth: 4 CompactNamespaces: false
NamespaceIndentation: None ConstructorInitializerAllOnOneLineOrOnePerLine: true
PointerAlignment: Left ConstructorInitializerIndentWidth: 4
SpaceAfterCStyleCast: false ContinuationIndentWidth: 4
SpaceAfterLogicalNot: true Cpp11BracedListStyle: true
SpaceBeforeAssignmentOperators: true DerivePointerAlignment: false
SpaceBeforeCpp11BracedList: false DisableFormat: false
SpaceBeforeCtorInitializerColon: true ExperimentalAutoDetectBinPacking: false
SpaceBeforeInheritanceColon: true FixNamespaceComments: true
SpaceBeforeParens: ControlStatements ForEachMacros:
SpaceBeforeRangeBasedForLoopColon: true - foreach
SpaceInEmptyBlock: true - Q_FOREACH
SpaceInEmptyParentheses: false - BOOST_FOREACH
SpacesInAngles: false
SpacesInConditionalStatement: true
SpacesInContainerLiterals: false
SpacesInParentheses: false
TabWidth: 4
UseTab: AlignWithSpaces
# Setting this to a high number causes clang-format to prefer breaking somewhere else
# over breaking after the assignment operator in a line that's over the column limit
PenaltyBreakAssignment: 100
IncludeBlocks: Regroup IncludeBlocks: Regroup
# Include categories go like this: # Include categories go like this:
@ -98,3 +90,57 @@ IncludeCategories:
Priority: 4 Priority: 4
- Regex: '.*' - Regex: '.*'
Priority: 5 Priority: 5
IncludeIsMainRegex: '$'
IndentCaseLabels: true
IndentPPDirectives: None
IndentWidth: 4
IndentWrappedFunctionNames: false
JavaScriptQuotes: Leave
JavaScriptWrapImports: true
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: '^BEGIN_'
MacroBlockEnd: '^END_'
MaxEmptyLinesToKeep: 2
NamespaceIndentation: None
ObjCBinPackProtocolList: Auto
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: true
PenaltyBreakAssignment: 2
PenaltyBreakBeforeFirstCallParameter: 500
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyBreakTemplateDeclaration: 10
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 1000
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SortUsingDeclarations: true
SpaceAfterCStyleCast: false
SpaceAfterTemplateKeyword: false
SpaceAfterLogicalNot: true
SpaceBeforeAssignmentOperators: true
SpaceBeforeCpp11BracedList: false
SpaceBeforeCtorInitializerColon: true
SpaceBeforeInheritanceColon: true
SpaceBeforeParens: ControlStatements
SpaceBeforeRangeBasedForLoopColon: true
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
SpacesInConditionalStatement: true
Standard: Cpp11
StatementMacros:
- STANDARD_OPERATOR_1
TabWidth: 4
UseTab: Never
---
Language: Json
...

View file

@ -3,9 +3,13 @@
# #
repos: repos:
- repo: https://github.com/pre-commit/mirrors-clang-format - repo: https://github.com/pre-commit/mirrors-clang-format
rev: 'v13.0.0' rev: 'v17.0.3'
hooks: hooks:
- id: clang-format - id: clang-format
types_or:
- "c"
- "c++"
- "json"
- repo: https://github.com/maxwinterstein/shfmt-py - repo: https://github.com/maxwinterstein/shfmt-py
rev: v3.7.0.1 rev: v3.7.0.1
@ -14,7 +18,7 @@ repos:
args: ["-w", "-i", "4", "-ci"] args: ["-w", "-i", "4", "-ci"]
- repo: https://github.com/google/yapf - repo: https://github.com/google/yapf
rev: v0.40.0 rev: v0.40.2
hooks: hooks:
- id: yapf - id: yapf
@ -25,7 +29,7 @@ repos:
exclude: '^auxil/.*$' exclude: '^auxil/.*$'
- repo: https://github.com/crate-ci/typos - repo: https://github.com/crate-ci/typos
rev: v1.16.8 rev: v1.16.21
hooks: hooks:
- id: typos - id: typos
exclude: '^(.typos.toml|src/SmithWaterman.cc|testing/.*|auxil/.*|scripts/base/frameworks/files/magic/.*|CHANGES)$' exclude: '^(.typos.toml|src/SmithWaterman.cc|testing/.*|auxil/.*|scripts/base/frameworks/files/magic/.*|CHANGES)$'

View file

@ -67,7 +67,6 @@ uses_seh = "uses_seh"
[default.extend-words] [default.extend-words]
caf = "caf" caf = "caf"
helo = "helo" helo = "helo"
inout = "inout"
# Seems we use this in the management framework # Seems we use this in the management framework
requestor = "requestor" requestor = "requestor"
# `inout` is used as a keyword in Spicy, but looks like a typo of `input`. # `inout` is used as a keyword in Spicy, but looks like a typo of `input`.

View file

@ -15,25 +15,20 @@
#include "zeek/net_util.h" #include "zeek/net_util.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS] = {nullptr}; AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS] = {nullptr};
static uint32_t rand32() static uint32_t rand32() {
{ return ((util::detail::random_number() & 0xffff) << 16) | (util::detail::random_number() & 0xffff);
return ((util::detail::random_number() & 0xffff) << 16) |
(util::detail::random_number() & 0xffff);
} }
// From tcpdpriv. // From tcpdpriv.
static int bi_ffs(uint32_t value) static int bi_ffs(uint32_t value) {
{
int add = 0; int add = 0;
static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}; static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1};
if ( (value & 0xFFFF0000) == 0 ) if ( (value & 0xFFFF0000) == 0 ) {
{
if ( value == 0 ) if ( value == 0 )
// Zero input ==> zero output. // Zero input ==> zero output.
return 0; return 0;
@ -59,13 +54,11 @@ static int bi_ffs(uint32_t value)
#define first_n_bit_mask(n) (~(0xFFFFFFFFU >> n)) #define first_n_bit_mask(n) (~(0xFFFFFFFFU >> n))
ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) {
{
std::map<ipaddr32_t, ipaddr32_t>::iterator p = mapping.find(addr); std::map<ipaddr32_t, ipaddr32_t>::iterator p = mapping.find(addr);
if ( p != mapping.end() ) if ( p != mapping.end() )
return p->second; return p->second;
else else {
{
ipaddr32_t new_addr = anonymize(addr); ipaddr32_t new_addr = anonymize(addr);
mapping[addr] = new_addr; mapping[addr] = new_addr;
@ -74,35 +67,26 @@ ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr)
} }
// Keep the specified prefix unchanged. // Keep the specified prefix unchanged.
bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) {
{
reporter->InternalError("prefix preserving is not supported for the anonymizer"); reporter->InternalError("prefix preserving is not supported for the anonymizer");
return false; return false;
} }
bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) {
{ switch ( addr_to_class(ntohl(input)) ) {
switch ( addr_to_class(ntohl(input)) ) case 'A': return PreservePrefix(input, 8);
{ case 'B': return PreservePrefix(input, 16);
case 'A': case 'C': return PreservePrefix(input, 24);
return PreservePrefix(input, 8); default: return false;
case 'B':
return PreservePrefix(input, 16);
case 'C':
return PreservePrefix(input, 24);
default:
return false;
} }
} }
ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) {
{
++seq; ++seq;
return htonl(seq); return htonl(seq);
} }
ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) {
{
uint8_t digest[16]; uint8_t digest[16];
ipaddr32_t output = 0; ipaddr32_t output = 0;
@ -119,15 +103,13 @@ ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input)
// //
// http://www.imconf.net/imw-2001/proceedings.html // http://www.imconf.net/imw-2001/proceedings.html
ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) {
{
uint8_t digest[16]; uint8_t digest[16];
ipaddr32_t prefix_mask = 0xffffffff; ipaddr32_t prefix_mask = 0xffffffff;
input = ntohl(input); input = ntohl(input);
ipaddr32_t output = input; ipaddr32_t output = input;
for ( int i = 0; i < 32; ++i ) for ( int i = 0; i < 32; ++i ) {
{
// PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 . // PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 .
prefix.len = htonl(i + 1); prefix.len = htonl(i + 1);
prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i))); prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i)));
@ -145,14 +127,12 @@ ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input)
return htonl(output); return htonl(output);
} }
AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() {
{
for ( auto& b : blocks ) for ( auto& b : blocks )
delete[] b; delete[] b;
} }
void AnonymizeIPAddr_A50::init() void AnonymizeIPAddr_A50::init() {
{
root = next_free_node = nullptr; root = next_free_node = nullptr;
// Prepare special nodes for 0.0.0.0 and 255.255.255.255. // Prepare special nodes for 0.0.0.0 and 255.255.255.255.
@ -165,12 +145,10 @@ void AnonymizeIPAddr_A50::init()
new_mapping = 0; new_mapping = 0;
} }
bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) {
{
DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits); DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits);
if ( ! before_anonymization ) if ( ! before_anonymization ) {
{
reporter->Error("prefix preservation specified after anonymization begun"); reporter->Error("prefix preservation specified after anonymization begun");
return false; return false;
} }
@ -186,8 +164,7 @@ bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits)
if ( num_bits == 32 ) if ( num_bits == 32 )
n->output = input; n->output = input;
else if ( num_bits > 0 ) else if ( num_bits > 0 ) {
{
assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU); assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU);
uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits); uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits);
uint32_t prefix_mask = ~suffix_mask; uint32_t prefix_mask = ~suffix_mask;
@ -197,13 +174,11 @@ bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits)
return true; return true;
} }
ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) {
{
before_anonymization = 0; before_anonymization = 0;
new_mapping = 0; new_mapping = 0;
if ( Node* n = find_node(ntohl(a)) ) if ( Node* n = find_node(ntohl(a)) ) {
{
ipaddr32_t output = htonl(n->output); ipaddr32_t output = htonl(n->output);
return output; return output;
} }
@ -211,8 +186,7 @@ ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a)
return 0; return 0;
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() {
{
assert(! next_free_node); assert(! next_free_node);
int block_size = 1024; int block_size = 1024;
@ -231,12 +205,10 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block()
return &block[0]; return &block[0];
} }
inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() {
{
new_mapping = 1; new_mapping = 1;
if ( next_free_node ) if ( next_free_node ) {
{
Node* n = next_free_node; Node* n = next_free_node;
next_free_node = n->child[0]; next_free_node = n->child[0];
return n; return n;
@ -245,19 +217,16 @@ inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node()
return new_node_block(); return new_node_block();
} }
inline void AnonymizeIPAddr_A50::free_node(Node* n) inline void AnonymizeIPAddr_A50::free_node(Node* n) {
{
n->child[0] = next_free_node; n->child[0] = next_free_node;
next_free_node = n; next_free_node = n;
} }
ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const {
{
// -A50 anonymization // -A50 anonymization
if ( swivel == 32 ) if ( swivel == 32 )
return old_output ^ 1; return old_output ^ 1;
else else {
{
// Bits up to swivel are unchanged; bit swivel is flipped. // Bits up to swivel are unchanged; bit swivel is flipped.
ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel); ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel);
@ -266,8 +235,7 @@ ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) c
} }
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) {
{
if ( a == 0 || a == 0xFFFFFFFFU ) if ( a == 0 || a == 0xFFFFFFFFU )
reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree"); reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree");
@ -281,8 +249,7 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n)
return nullptr; return nullptr;
down[1] = new_node(); down[1] = new_node();
if ( ! down[1] ) if ( ! down[1] ) {
{
free_node(down[0]); free_node(down[0]);
return nullptr; return nullptr;
} }
@ -307,15 +274,13 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n)
return down[bitvalue]; return down[bitvalue];
} }
AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) {
{
// Watch out for special IP addresses, which never make it // Watch out for special IP addresses, which never make it
// into the tree. // into the tree.
if ( a == 0 || a == 0xFFFFFFFFU ) if ( a == 0 || a == 0xFFFFFFFFU )
return &special_nodes[a & 1]; return &special_nodes[a & 1];
if ( ! root ) if ( ! root ) {
{
root = new_node(); root = new_node();
root->input = a; root->input = a;
root->output = rand32(); root->output = rand32();
@ -326,16 +291,14 @@ AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a)
// Straight from tcpdpriv. // Straight from tcpdpriv.
Node* n = root; Node* n = root;
while ( n ) while ( n ) {
{
if ( n->input == a ) if ( n->input == a )
return n; return n;
if ( ! n->child[0] ) if ( ! n->child[0] )
n = make_peer(a, n); n = make_peer(a, n);
else else {
{
// swivel is the first bit in which the two children // swivel is the first bit in which the two children
// differ. // differ.
int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input); int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input);
@ -360,8 +323,7 @@ static TableValPtr anon_preserve_orig_addr;
static TableValPtr anon_preserve_resp_addr; static TableValPtr anon_preserve_resp_addr;
static TableValPtr anon_preserve_other_addr; static TableValPtr anon_preserve_other_addr;
void init_ip_addr_anonymizers() void init_ip_addr_anonymizers() {
{
ip_anonymizer[KEEP_ORIG_ADDR] = nullptr; ip_anonymizer[KEEP_ORIG_ADDR] = nullptr;
ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq(); ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq();
ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5(); ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5();
@ -384,15 +346,13 @@ void init_ip_addr_anonymizers()
anon_preserve_other_addr = cast_intrusive<TableVal>(id->GetVal()); anon_preserve_other_addr = cast_intrusive<TableVal>(id->GetVal());
} }
ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) {
{
TableVal* preserve_addr = nullptr; TableVal* preserve_addr = nullptr;
auto addr = make_intrusive<AddrVal>(ip); auto addr = make_intrusive<AddrVal>(ip);
int method = -1; int method = -1;
switch ( cl ) switch ( cl ) {
{
case ORIG_ADDR: // client address case ORIG_ADDR: // client address
preserve_addr = anon_preserve_orig_addr.get(); preserve_addr = anon_preserve_orig_addr.get();
method = orig_addr_anonymization; method = orig_addr_anonymization;
@ -414,8 +374,7 @@ ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl)
if ( preserve_addr && preserve_addr->FindOrDefault(addr) ) if ( preserve_addr && preserve_addr->FindOrDefault(addr) )
new_ip = ip; new_ip = ip;
else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) {
{
if ( method == KEEP_ORIG_ADDR ) if ( method == KEEP_ORIG_ADDR )
new_ip = ip; new_ip = ip;
@ -437,11 +396,9 @@ ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl)
#ifdef LOG_ANONYMIZATION_MAPPING #ifdef LOG_ANONYMIZATION_MAPPING
void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) {
{
if ( anonymization_mapping ) if ( anonymization_mapping )
event_mgr.Enqueue(anonymization_mapping, make_intrusive<AddrVal>(input), event_mgr.Enqueue(anonymization_mapping, make_intrusive<AddrVal>(input), make_intrusive<AddrVal>(output));
make_intrusive<AddrVal>(output));
} }
#endif #endif

View file

@ -14,21 +14,18 @@
#include <map> #include <map>
#include <vector> #include <vector>
namespace zeek::detail namespace zeek::detail {
{
// TODO: Anon.h may not be the right place to put these functions ... // TODO: Anon.h may not be the right place to put these functions ...
enum ip_addr_anonymization_class_t enum ip_addr_anonymization_class_t {
{
ORIG_ADDR, // client address ORIG_ADDR, // client address
RESP_ADDR, // server address RESP_ADDR, // server address
OTHER_ADDR, OTHER_ADDR,
NUM_ADDR_ANONYMIZATION_CLASSES, NUM_ADDR_ANONYMIZATION_CLASSES,
}; };
enum ip_addr_anonymization_method_t enum ip_addr_anonymization_method_t {
{
KEEP_ORIG_ADDR, KEEP_ORIG_ADDR,
SEQUENTIALLY_NUMBERED, SEQUENTIALLY_NUMBERED,
RANDOM_MD5, RANDOM_MD5,
@ -42,8 +39,7 @@ using ipaddr32_t = uint32_t;
// NOTE: all addresses in parameters of *public* functions are in // NOTE: all addresses in parameters of *public* functions are in
// network order. // network order.
class AnonymizeIPAddr class AnonymizeIPAddr {
{
public: public:
virtual ~AnonymizeIPAddr() = default; virtual ~AnonymizeIPAddr() = default;
@ -59,8 +55,7 @@ protected:
std::map<ipaddr32_t, ipaddr32_t> mapping; std::map<ipaddr32_t, ipaddr32_t> mapping;
}; };
class AnonymizeIPAddr_Seq : public AnonymizeIPAddr class AnonymizeIPAddr_Seq : public AnonymizeIPAddr {
{
public: public:
AnonymizeIPAddr_Seq() { seq = 1; } AnonymizeIPAddr_Seq() { seq = 1; }
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
@ -69,27 +64,23 @@ protected:
ipaddr32_t seq; ipaddr32_t seq;
}; };
class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr {
{
public: public:
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
}; };
class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr {
{
public: public:
ipaddr32_t anonymize(ipaddr32_t addr) override; ipaddr32_t anonymize(ipaddr32_t addr) override;
protected: protected:
struct anon_prefix struct anon_prefix {
{
int len; int len;
ipaddr32_t prefix; ipaddr32_t prefix;
} prefix; } prefix;
}; };
class AnonymizeIPAddr_A50 : public AnonymizeIPAddr class AnonymizeIPAddr_A50 : public AnonymizeIPAddr {
{
public: public:
AnonymizeIPAddr_A50() { init(); } AnonymizeIPAddr_A50() { init(); }
~AnonymizeIPAddr_A50() override; ~AnonymizeIPAddr_A50() override;
@ -98,8 +89,7 @@ public:
bool PreservePrefix(ipaddr32_t input, int num_bits) override; bool PreservePrefix(ipaddr32_t input, int num_bits) override;
protected: protected:
struct Node struct Node {
{
ipaddr32_t input; ipaddr32_t input;
ipaddr32_t output; ipaddr32_t output;
Node* child[2]; Node* child[2];

View file

@ -11,11 +11,9 @@
#include "zeek/input/Manager.h" #include "zeek/input/Manager.h"
#include "zeek/threading/SerialTypes.h" #include "zeek/threading/SerialTypes.h"
namespace zeek::detail namespace zeek::detail {
{
const char* attr_name(AttrTag t) const char* attr_name(AttrTag t) {
{
// Do not collapse the list. // Do not collapse the list.
// clang-format off // clang-format off
static const char* attr_names[int(NUM_ATTRS)] = { static const char* attr_names[int(NUM_ATTRS)] = {
@ -50,21 +48,16 @@ const char* attr_name(AttrTag t)
return attr_names[int(t)]; return attr_names[int(t)];
} }
Attr::Attr(AttrTag t, ExprPtr e) : expr(std::move(e)) Attr::Attr(AttrTag t, ExprPtr e) : expr(std::move(e)) {
{
tag = t; tag = t;
SetLocationInfo(&start_location, &end_location); SetLocationInfo(&start_location, &end_location);
} }
Attr::Attr(AttrTag t) : Attr(t, nullptr) {} Attr::Attr(AttrTag t) : Attr(t, nullptr) {}
void Attr::SetAttrExpr(ExprPtr e) void Attr::SetAttrExpr(ExprPtr e) { expr = std::move(e); }
{
expr = std::move(e);
}
std::string Attr::DeprecationMessage() const std::string Attr::DeprecationMessage() const {
{
if ( tag != ATTR_DEPRECATED ) if ( tag != ATTR_DEPRECATED )
return ""; return "";
@ -75,12 +68,10 @@ std::string Attr::DeprecationMessage() const
return ce->Value()->AsStringVal()->CheckString(); return ce->Value()->AsStringVal()->CheckString();
} }
void Attr::Describe(ODesc* d) const void Attr::Describe(ODesc* d) const {
{
AddTag(d); AddTag(d);
if ( expr ) if ( expr ) {
{
if ( ! d->IsBinary() ) if ( ! d->IsBinary() )
d->Add("="); d->Add("=");
@ -88,27 +79,22 @@ void Attr::Describe(ODesc* d) const
} }
} }
void Attr::DescribeReST(ODesc* d, bool shorten) const void Attr::DescribeReST(ODesc* d, bool shorten) const {
{ auto add_long_expr_string = [](ODesc* d, const std::string& s, bool shorten) {
auto add_long_expr_string = [](ODesc* d, const std::string& s, bool shorten)
{
constexpr auto max_expr_chars = 32; constexpr auto max_expr_chars = 32;
constexpr auto shortened_expr = "*...*"; constexpr auto shortened_expr = "*...*";
if ( s.size() > max_expr_chars ) if ( s.size() > max_expr_chars ) {
{
if ( shorten ) if ( shorten )
d->Add(shortened_expr); d->Add(shortened_expr);
else else {
{
// Long inline-literals likely won't wrap well in HTML render // Long inline-literals likely won't wrap well in HTML render
d->Add("*"); d->Add("*");
d->Add(s); d->Add(s);
d->Add("*"); d->Add("*");
} }
} }
else else {
{
d->Add("``"); d->Add("``");
d->Add(s); d->Add(s);
d->Add("``"); d->Add("``");
@ -119,28 +105,24 @@ void Attr::DescribeReST(ODesc* d, bool shorten) const
AddTag(d); AddTag(d);
d->Add("`"); d->Add("`");
if ( expr ) if ( expr ) {
{
d->SP(); d->SP();
d->Add("="); d->Add("=");
d->SP(); d->SP();
if ( expr->Tag() == EXPR_NAME ) if ( expr->Tag() == EXPR_NAME ) {
{
d->Add(":zeek:see:`"); d->Add(":zeek:see:`");
expr->Describe(d); expr->Describe(d);
d->Add("`"); d->Add("`");
} }
else if ( expr->GetType()->Tag() == TYPE_FUNC ) else if ( expr->GetType()->Tag() == TYPE_FUNC ) {
{
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
d->Add(expr->GetType()->AsFuncType()->FlavorString()); d->Add(expr->GetType()->AsFuncType()->FlavorString());
d->Add("`"); d->Add("`");
} }
else if ( expr->Tag() == EXPR_CONST ) else if ( expr->Tag() == EXPR_CONST ) {
{
ODesc dd; ODesc dd;
dd.SetQuotes(true); dd.SetQuotes(true);
expr->Describe(&dd); expr->Describe(&dd);
@ -148,8 +130,7 @@ void Attr::DescribeReST(ODesc* d, bool shorten) const
add_long_expr_string(d, s, shorten); add_long_expr_string(d, s, shorten);
} }
else else {
{
ODesc dd; ODesc dd;
expr->Eval(nullptr)->Describe(&dd); expr->Eval(nullptr)->Describe(&dd);
std::string s = dd.Description(); std::string s = dd.Description();
@ -163,21 +144,18 @@ void Attr::DescribeReST(ODesc* d, bool shorten) const
} }
} }
void Attr::AddTag(ODesc* d) const void Attr::AddTag(ODesc* d) const {
{
if ( d->IsBinary() ) if ( d->IsBinary() )
d->Add(static_cast<zeek_int_t>(Tag())); d->Add(static_cast<zeek_int_t>(Tag()));
else else
d->Add(attr_name(Tag())); d->Add(attr_name(Tag()));
} }
detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const {
{
auto tc = cb->PreAttr(this); auto tc = cb->PreAttr(this);
HANDLE_TC_ATTR_PRE(tc); HANDLE_TC_ATTR_PRE(tc);
if ( expr ) if ( expr ) {
{
auto tc = expr->Traverse(cb); auto tc = expr->Traverse(cb);
HANDLE_TC_ATTR_PRE(tc); HANDLE_TC_ATTR_PRE(tc);
} }
@ -187,13 +165,9 @@ detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const
} }
Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global) Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global)
: Attributes(std::vector<AttrPtr>{}, std::move(t), arg_in_record, is_global) : Attributes(std::vector<AttrPtr>{}, std::move(t), arg_in_record, is_global) {}
{
}
Attributes::Attributes(std::vector<AttrPtr> a, TypePtr t, bool arg_in_record, bool is_global) Attributes::Attributes(std::vector<AttrPtr> a, TypePtr t, bool arg_in_record, bool is_global) : type(std::move(t)) {
: type(std::move(t))
{
attrs.reserve(a.size()); attrs.reserve(a.size());
in_record = arg_in_record; in_record = arg_in_record;
global_var = is_global; global_var = is_global;
@ -208,19 +182,15 @@ Attributes::Attributes(std::vector<AttrPtr> a, TypePtr t, bool arg_in_record, bo
AddAttr(std::move(attr)); AddAttr(std::move(attr));
} }
void Attributes::AddAttr(AttrPtr attr, bool is_redef) void Attributes::AddAttr(AttrPtr attr, bool is_redef) {
{ auto acceptable_duplicate_attr = [](const AttrPtr& attr, const AttrPtr& existing) -> bool {
auto acceptable_duplicate_attr = [](const AttrPtr& attr, const AttrPtr& existing) -> bool
{
if ( attr == existing ) if ( attr == existing )
return true; return true;
AttrTag new_tag = attr->Tag(); AttrTag new_tag = attr->Tag();
if ( new_tag == ATTR_DEPRECATED ) if ( new_tag == ATTR_DEPRECATED ) {
{ if ( ! attr->DeprecationMessage().empty() || (existing && ! existing->DeprecationMessage().empty()) )
if ( ! attr->DeprecationMessage().empty() ||
(existing && ! existing->DeprecationMessage().empty()) )
return false; return false;
return true; return true;
@ -233,8 +203,7 @@ void Attributes::AddAttr(AttrPtr attr, bool is_redef)
// A `redef` is allowed to overwrite an existing attribute instead of // A `redef` is allowed to overwrite an existing attribute instead of
// flagging it as ambiguous. // flagging it as ambiguous.
if ( ! is_redef ) if ( ! is_redef ) {
{
auto existing = Find(attr->Tag()); auto existing = Find(attr->Tag());
if ( existing && ! acceptable_duplicate_attr(attr, existing) ) if ( existing && ! acceptable_duplicate_attr(attr, existing) )
reporter->Error("Duplicate %s attribute is ambiguous", attr_name(attr->Tag())); reporter->Error("Duplicate %s attribute is ambiguous", attr_name(attr->Tag()));
@ -254,28 +223,24 @@ void Attributes::AddAttr(AttrPtr attr, bool is_redef)
// For ADD_FUNC or DEL_FUNC, add in an implicit REDEF, since // For ADD_FUNC or DEL_FUNC, add in an implicit REDEF, since
// those attributes only have meaning for a redefinable value. // those attributes only have meaning for a redefinable value.
if ( (attr->Tag() == ATTR_ADD_FUNC || attr->Tag() == ATTR_DEL_FUNC) && ! Find(ATTR_REDEF) ) if ( (attr->Tag() == ATTR_ADD_FUNC || attr->Tag() == ATTR_DEL_FUNC) && ! Find(ATTR_REDEF) ) {
{
auto a = make_intrusive<Attr>(ATTR_REDEF); auto a = make_intrusive<Attr>(ATTR_REDEF);
attrs.emplace_back(std::move(a)); attrs.emplace_back(std::move(a));
} }
// For DEFAULT, add an implicit OPTIONAL if it's not a global. // For DEFAULT, add an implicit OPTIONAL if it's not a global.
if ( ! global_var && attr->Tag() == ATTR_DEFAULT && ! Find(ATTR_OPTIONAL) ) if ( ! global_var && attr->Tag() == ATTR_DEFAULT && ! Find(ATTR_OPTIONAL) ) {
{
auto a = make_intrusive<Attr>(ATTR_OPTIONAL); auto a = make_intrusive<Attr>(ATTR_OPTIONAL);
attrs.emplace_back(std::move(a)); attrs.emplace_back(std::move(a));
} }
} }
void Attributes::AddAttrs(const AttributesPtr& a, bool is_redef) void Attributes::AddAttrs(const AttributesPtr& a, bool is_redef) {
{
for ( const auto& attr : a->GetAttrs() ) for ( const auto& attr : a->GetAttrs() )
AddAttr(attr, is_redef); AddAttr(attr, is_redef);
} }
const AttrPtr& Attributes::Find(AttrTag t) const const AttrPtr& Attributes::Find(AttrTag t) const {
{
for ( const auto& a : attrs ) for ( const auto& a : attrs )
if ( a->Tag() == t ) if ( a->Tag() == t )
return a; return a;
@ -283,10 +248,8 @@ const AttrPtr& Attributes::Find(AttrTag t) const
return Attr::nil; return Attr::nil;
} }
void Attributes::RemoveAttr(AttrTag t) void Attributes::RemoveAttr(AttrTag t) {
{ for ( auto it = attrs.begin(); it != attrs.end(); ) {
for ( auto it = attrs.begin(); it != attrs.end(); )
{
if ( (*it)->Tag() == t ) if ( (*it)->Tag() == t )
it = attrs.erase(it); it = attrs.erase(it);
else else
@ -294,18 +257,15 @@ void Attributes::RemoveAttr(AttrTag t)
} }
} }
void Attributes::Describe(ODesc* d) const void Attributes::Describe(ODesc* d) const {
{ if ( attrs.empty() ) {
if ( attrs.empty() )
{
d->AddCount(0); d->AddCount(0);
return; return;
} }
d->AddCount(static_cast<uint64_t>(attrs.size())); d->AddCount(static_cast<uint64_t>(attrs.size()));
for ( size_t i = 0; i < attrs.size(); ++i ) for ( size_t i = 0; i < attrs.size(); ++i ) {
{
if ( d->IsReadable() && i > 0 ) if ( d->IsReadable() && i > 0 )
d->Add(", "); d->Add(", ");
@ -313,10 +273,8 @@ void Attributes::Describe(ODesc* d) const
} }
} }
void Attributes::DescribeReST(ODesc* d, bool shorten) const void Attributes::DescribeReST(ODesc* d, bool shorten) const {
{ for ( size_t i = 0; i < attrs.size(); ++i ) {
for ( size_t i = 0; i < attrs.size(); ++i )
{
if ( i > 0 ) if ( i > 0 )
d->Add(" "); d->Add(" ");
@ -324,15 +282,12 @@ void Attributes::DescribeReST(ODesc* d, bool shorten) const
} }
} }
void Attributes::CheckAttr(Attr* a) void Attributes::CheckAttr(Attr* a) {
{ switch ( a->Tag() ) {
switch ( a->Tag() )
{
case ATTR_DEPRECATED: case ATTR_DEPRECATED:
case ATTR_REDEF: case ATTR_REDEF:
case ATTR_IS_ASSIGNED: case ATTR_IS_ASSIGNED:
case ATTR_IS_USED: case ATTR_IS_USED: break;
break;
case ATTR_OPTIONAL: case ATTR_OPTIONAL:
if ( global_var ) if ( global_var )
@ -340,67 +295,53 @@ void Attributes::CheckAttr(Attr* a)
break; break;
case ATTR_ADD_FUNC: case ATTR_ADD_FUNC:
case ATTR_DEL_FUNC: case ATTR_DEL_FUNC: {
{
bool is_add = a->Tag() == ATTR_ADD_FUNC; bool is_add = a->Tag() == ATTR_ADD_FUNC;
const auto& at = a->GetExpr()->GetType(); const auto& at = a->GetExpr()->GetType();
if ( at->Tag() != TYPE_FUNC ) if ( at->Tag() != TYPE_FUNC ) {
{ a->GetExpr()->Error(is_add ? "&add_func must be a function" : "&delete_func must be a function");
a->GetExpr()->Error(is_add ? "&add_func must be a function"
: "&delete_func must be a function");
break; break;
} }
FuncType* aft = at->AsFuncType(); FuncType* aft = at->AsFuncType();
if ( ! same_type(aft->Yield(), type) ) if ( ! same_type(aft->Yield(), type) ) {
{ a->GetExpr()->Error(is_add ? "&add_func function must yield same type as variable" :
a->GetExpr()->Error(is_add "&delete_func function must yield same type as variable");
? "&add_func function must yield same type as variable"
: "&delete_func function must yield same type as variable");
break; break;
} }
} } break;
break;
case ATTR_DEFAULT_INSERT: case ATTR_DEFAULT_INSERT: {
{ if ( ! type->IsTable() ) {
if ( ! type->IsTable() )
{
Error("&default_insert only applicable to tables"); Error("&default_insert only applicable to tables");
break; break;
} }
if ( Find(ATTR_DEFAULT) ) if ( Find(ATTR_DEFAULT) ) {
{
Error("&default and &default_insert cannot be used together"); Error("&default and &default_insert cannot be used together");
break; break;
} }
std::string err_msg; std::string err_msg;
if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && ! err_msg.empty() )
! err_msg.empty() )
Error(err_msg.c_str()); Error(err_msg.c_str());
break; break;
} }
case ATTR_DEFAULT: case ATTR_DEFAULT: {
{ if ( Find(ATTR_DEFAULT_INSERT) ) {
if ( Find(ATTR_DEFAULT_INSERT) )
{
Error("&default and &default_insert cannot be used together"); Error("&default and &default_insert cannot be used together");
break; break;
} }
std::string err_msg; std::string err_msg;
if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && ! err_msg.empty() )
! err_msg.empty() )
Error(err_msg.c_str()); Error(err_msg.c_str());
break; break;
} }
case ATTR_EXPIRE_READ: case ATTR_EXPIRE_READ: {
{
if ( Find(ATTR_BROKER_STORE) ) if ( Find(ATTR_BROKER_STORE) )
Error("&broker_store and &read_expire cannot be used simultaneously"); Error("&broker_store and &read_expire cannot be used simultaneously");
@ -410,26 +351,23 @@ void Attributes::CheckAttr(Attr* a)
// fallthrough // fallthrough
case ATTR_EXPIRE_WRITE: case ATTR_EXPIRE_WRITE:
case ATTR_EXPIRE_CREATE: case ATTR_EXPIRE_CREATE: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("expiration only applicable to sets/tables"); Error("expiration only applicable to sets/tables");
break; break;
} }
int num_expires = 0; int num_expires = 0;
for ( const auto& at : attrs ) for ( const auto& at : attrs ) {
{
if ( at->Tag() == ATTR_EXPIRE_READ || at->Tag() == ATTR_EXPIRE_WRITE || if ( at->Tag() == ATTR_EXPIRE_READ || at->Tag() == ATTR_EXPIRE_WRITE ||
at->Tag() == ATTR_EXPIRE_CREATE ) at->Tag() == ATTR_EXPIRE_CREATE )
num_expires++; num_expires++;
} }
if ( num_expires > 1 ) if ( num_expires > 1 ) {
{ Error(
Error("set/table can only have one of &read_expire, &write_expire, " "set/table can only have one of &read_expire, &write_expire, "
"&create_expire"); "&create_expire");
break; break;
} }
@ -443,10 +381,8 @@ void Attributes::CheckAttr(Attr* a)
break; break;
case ATTR_EXPIRE_FUNC: case ATTR_EXPIRE_FUNC: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("expiration only applicable to tables"); Error("expiration only applicable to tables");
break; break;
} }
@ -462,10 +398,8 @@ void Attributes::CheckAttr(Attr* a)
break; break;
} }
case ATTR_ON_CHANGE: case ATTR_ON_CHANGE: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("&on_change only applicable to sets/tables"); Error("&on_change only applicable to sets/tables");
break; break;
} }
@ -478,8 +412,7 @@ void Attributes::CheckAttr(Attr* a)
const FuncType* c_ft = change_func->GetType()->AsFuncType(); const FuncType* c_ft = change_func->GetType()->AsFuncType();
if ( c_ft->Yield()->Tag() != TYPE_VOID ) if ( c_ft->Yield()->Tag() != TYPE_VOID ) {
{
Error("&on_change must not return a value"); Error("&on_change must not return a value");
break; break;
} }
@ -491,47 +424,38 @@ void Attributes::CheckAttr(Attr* a)
const auto& args = c_ft->ParamList()->GetTypes(); const auto& args = c_ft->ParamList()->GetTypes();
const auto& t_indexes = the_table->GetIndexTypes(); const auto& t_indexes = the_table->GetIndexTypes();
if ( args.size() != (type->IsSet() ? 2 : 3) + t_indexes.size() ) if ( args.size() != (type->IsSet() ? 2 : 3) + t_indexes.size() ) {
{
Error("&on_change function has incorrect number of arguments"); Error("&on_change function has incorrect number of arguments");
break; break;
} }
if ( ! same_type(args[0], the_table->AsTableType()) ) if ( ! same_type(args[0], the_table->AsTableType()) ) {
{
Error("&on_change: first argument must be of same type as table"); Error("&on_change: first argument must be of same type as table");
break; break;
} }
// can't check exact type here yet - the data structures don't exist yet. // can't check exact type here yet - the data structures don't exist yet.
if ( args[1]->Tag() != TYPE_ENUM ) if ( args[1]->Tag() != TYPE_ENUM ) {
{
Error("&on_change: second argument must be a TableChange enum"); Error("&on_change: second argument must be a TableChange enum");
break; break;
} }
for ( size_t i = 0; i < t_indexes.size(); i++ ) for ( size_t i = 0; i < t_indexes.size(); i++ ) {
{ if ( ! same_type(args[2 + i], t_indexes[i]) ) {
if ( ! same_type(args[2 + i], t_indexes[i]) )
{
Error("&on_change: index types do not match table"); Error("&on_change: index types do not match table");
break; break;
} }
} }
if ( ! type->IsSet() ) if ( ! type->IsSet() )
if ( ! same_type(args[2 + t_indexes.size()], the_table->Yield()) ) if ( ! same_type(args[2 + t_indexes.size()], the_table->Yield()) ) {
{
Error("&on_change: value type does not match table"); Error("&on_change: value type does not match table");
break; break;
} }
} } break;
break;
case ATTR_BACKEND: case ATTR_BACKEND: {
{ if ( ! global_var || type->Tag() != TYPE_TABLE ) {
if ( ! global_var || type->Tag() != TYPE_TABLE )
{
Error("&backend only applicable to global sets/tables"); Error("&backend only applicable to global sets/tables");
break; break;
} }
@ -539,8 +463,7 @@ void Attributes::CheckAttr(Attr* a)
// cannot do better equality check - the Broker types are not // cannot do better equality check - the Broker types are not
// actually existing yet when we are here. We will do that // actually existing yet when we are here. We will do that
// later - before actually attaching to a broker store // later - before actually attaching to a broker store
if ( a->GetExpr()->GetType()->Tag() != TYPE_ENUM ) if ( a->GetExpr()->GetType()->Tag() != TYPE_ENUM ) {
{
Error("&backend must take an enum argument"); Error("&backend must take an enum argument");
break; break;
} }
@ -549,8 +472,7 @@ void Attributes::CheckAttr(Attr* a)
// explicitly overridden // explicitly overridden
if ( ! type->AsTableType()->IsSet() && if ( ! type->AsTableType()->IsSet() &&
! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) &&
! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) {
{
Error("&backend only supports atomic types as table value"); Error("&backend only supports atomic types as table value");
} }
@ -566,16 +488,13 @@ void Attributes::CheckAttr(Attr* a)
break; break;
} }
case ATTR_BROKER_STORE: case ATTR_BROKER_STORE: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("&broker_store only applicable to sets/tables"); Error("&broker_store only applicable to sets/tables");
break; break;
} }
if ( a->GetExpr()->GetType()->Tag() != TYPE_STRING ) if ( a->GetExpr()->GetType()->Tag() != TYPE_STRING ) {
{
Error("&broker_store must take a string argument"); Error("&broker_store must take a string argument");
break; break;
} }
@ -584,8 +503,7 @@ void Attributes::CheckAttr(Attr* a)
// explicitly overridden // explicitly overridden
if ( ! type->AsTableType()->IsSet() && if ( ! type->AsTableType()->IsSet() &&
! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) &&
! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) {
{
Error("&broker_store only supports atomic types as table value"); Error("&broker_store only supports atomic types as table value");
} }
@ -601,10 +519,8 @@ void Attributes::CheckAttr(Attr* a)
break; break;
} }
case ATTR_BROKER_STORE_ALLOW_COMPLEX: case ATTR_BROKER_STORE_ALLOW_COMPLEX: {
{ if ( type->Tag() != TYPE_TABLE ) {
if ( type->Tag() != TYPE_TABLE )
{
Error("&broker_allow_complex_type only applicable to sets/tables"); Error("&broker_allow_complex_type only applicable to sets/tables");
break; break;
} }
@ -619,9 +535,7 @@ void Attributes::CheckAttr(Attr* a)
Error("&raw_output only applicable to files"); Error("&raw_output only applicable to files");
break; break;
case ATTR_PRIORITY: case ATTR_PRIORITY: Error("&priority only applicable to event bodies"); break;
Error("&priority only applicable to event bodies");
break;
case ATTR_GROUP: case ATTR_GROUP:
if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT ) if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT )
@ -638,18 +552,15 @@ void Attributes::CheckAttr(Attr* a)
Error("&log applied to a type that cannot be logged"); Error("&log applied to a type that cannot be logged");
break; break;
case ATTR_TYPE_COLUMN: case ATTR_TYPE_COLUMN: {
{ if ( type->Tag() != TYPE_PORT ) {
if ( type->Tag() != TYPE_PORT )
{
Error("type_column tag only applicable to ports"); Error("type_column tag only applicable to ports");
break; break;
} }
const auto& atype = a->GetExpr()->GetType(); const auto& atype = a->GetExpr()->GetType();
if ( atype->Tag() != TYPE_STRING ) if ( atype->Tag() != TYPE_STRING ) {
{
Error("type column needs to have a string argument"); Error("type column needs to have a string argument");
break; break;
} }
@ -662,21 +573,18 @@ void Attributes::CheckAttr(Attr* a)
Error("&ordered only applicable to tables"); Error("&ordered only applicable to tables");
break; break;
default: default: BadTag("Attributes::CheckAttr", attr_name(a->Tag()));
BadTag("Attributes::CheckAttr", attr_name(a->Tag()));
} }
} }
bool Attributes::operator==(const Attributes& other) const bool Attributes::operator==(const Attributes& other) const {
{
if ( attrs.empty() ) if ( attrs.empty() )
return other.attrs.empty(); return other.attrs.empty();
if ( other.attrs.empty() ) if ( other.attrs.empty() )
return false; return false;
for ( const auto& a : attrs ) for ( const auto& a : attrs ) {
{
const auto& o = other.Find(a->Tag()); const auto& o = other.Find(a->Tag());
if ( ! o ) if ( ! o )
@ -686,8 +594,7 @@ bool Attributes::operator==(const Attributes& other) const
return false; return false;
} }
for ( const auto& o : other.attrs ) for ( const auto& o : other.attrs ) {
{
const auto& a = Find(o->Tag()); const auto& a = Find(o->Tag());
if ( ! a ) if ( ! a )
@ -700,23 +607,19 @@ bool Attributes::operator==(const Attributes& other) const
return true; return true;
} }
bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, std::string& err_msg) {
std::string& err_msg)
{
ASSERT(a->Tag() == ATTR_DEFAULT || a->Tag() == ATTR_DEFAULT_INSERT); ASSERT(a->Tag() == ATTR_DEFAULT || a->Tag() == ATTR_DEFAULT_INSERT);
std::string aname = attr_name(a->Tag()); std::string aname = attr_name(a->Tag());
// &default is allowed for global tables, since it's used in // &default is allowed for global tables, since it's used in
// initialization of table fields. It's not allowed otherwise. // initialization of table fields. It's not allowed otherwise.
if ( global_var && ! type->IsTable() ) if ( global_var && ! type->IsTable() ) {
{
err_msg = aname + " is not valid for global variables except for tables"; err_msg = aname + " is not valid for global variables except for tables";
return false; return false;
} }
const auto& atype = a->GetExpr()->GetType(); const auto& atype = a->GetExpr()->GetType();
if ( type->Tag() != TYPE_TABLE || (type->IsSet() && ! in_record) ) if ( type->Tag() != TYPE_TABLE || (type->IsSet() && ! in_record) ) {
{
if ( same_type(atype, type) ) if ( same_type(atype, type) )
// Ok. // Ok.
return true; return true;
@ -733,8 +636,7 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
auto e = check_and_promote_expr(a->GetExpr(), type); auto e = check_and_promote_expr(a->GetExpr(), type);
if ( e ) if ( e ) {
{
a->SetAttrExpr(std::move(e)); a->SetAttrExpr(std::move(e));
// Ok. // Ok.
return true; return true;
@ -747,17 +649,14 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
TableType* tt = type->AsTableType(); TableType* tt = type->AsTableType();
const auto& ytype = tt->Yield(); const auto& ytype = tt->Yield();
if ( ! in_record ) if ( ! in_record ) { // &default applies to the type itself.
{ // &default applies to the type itself.
if ( same_type(atype, ytype) ) if ( same_type(atype, ytype) )
return true; return true;
// It can still be a default function. // It can still be a default function.
if ( atype->Tag() == TYPE_FUNC ) if ( atype->Tag() == TYPE_FUNC ) {
{
FuncType* f = atype->AsFuncType(); FuncType* f = atype->AsFuncType();
if ( ! f->CheckArgs(tt->GetIndexTypes()) || ! same_type(f->Yield(), ytype) ) if ( ! f->CheckArgs(tt->GetIndexTypes()) || ! same_type(f->Yield(), ytype) ) {
{
err_msg = aname + " function type clash"; err_msg = aname + " function type clash";
return false; return false;
} }
@ -774,8 +673,7 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
auto e = check_and_promote_expr(a->GetExpr(), ytype); auto e = check_and_promote_expr(a->GetExpr(), ytype);
if ( e ) if ( e ) {
{
a->SetAttrExpr(std::move(e)); a->SetAttrExpr(std::move(e));
// Ok. // Ok.
return true; return true;
@ -790,12 +688,10 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
if ( same_type(atype, type) ) if ( same_type(atype, type) )
return true; return true;
if ( (atype->Tag() == TYPE_TABLE && atype->AsTableType()->IsUnspecifiedTable()) ) if ( (atype->Tag() == TYPE_TABLE && atype->AsTableType()->IsUnspecifiedTable()) ) {
{
auto e = check_and_promote_expr(a->GetExpr(), type); auto e = check_and_promote_expr(a->GetExpr(), type);
if ( e ) if ( e ) {
{
a->SetAttrExpr(std::move(e)); a->SetAttrExpr(std::move(e));
return true; return true;
} }
@ -811,13 +707,11 @@ bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_r
return false; return false;
} }
detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const {
{
auto tc = cb->PreAttrs(this); auto tc = cb->PreAttrs(this);
HANDLE_TC_ATTRS_PRE(tc); HANDLE_TC_ATTRS_PRE(tc);
for ( const auto& a : attrs ) for ( const auto& a : attrs ) {
{
tc = a->Traverse(cb); tc = a->Traverse(cb);
HANDLE_TC_ATTRS_PRE(tc); HANDLE_TC_ATTRS_PRE(tc);
} }
@ -826,4 +720,4 @@ detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const
HANDLE_TC_ATTRS_POST(tc); HANDLE_TC_ATTRS_POST(tc);
} }
} } // namespace zeek::detail

View file

@ -14,20 +14,17 @@
// modify expressions or supply metadata on types, and the kind that // modify expressions or supply metadata on types, and the kind that
// are extra metadata on every variable instance. // are extra metadata on every variable instance.
namespace zeek namespace zeek {
{
class Type; class Type;
using TypePtr = IntrusivePtr<Type>; using TypePtr = IntrusivePtr<Type>;
namespace detail namespace detail {
{
class Expr; class Expr;
using ExprPtr = IntrusivePtr<Expr>; using ExprPtr = IntrusivePtr<Expr>;
enum AttrTag enum AttrTag {
{
ATTR_OPTIONAL, ATTR_OPTIONAL,
ATTR_DEFAULT, ATTR_DEFAULT,
ATTR_DEFAULT_INSERT, // insert default value on failed lookups ATTR_DEFAULT_INSERT, // insert default value on failed lookups
@ -61,8 +58,7 @@ using AttrPtr = IntrusivePtr<Attr>;
class Attributes; class Attributes;
using AttributesPtr = IntrusivePtr<Attributes>; using AttributesPtr = IntrusivePtr<Attributes>;
class Attr final : public Obj class Attr final : public Obj {
{
public: public:
static inline const AttrPtr nil; static inline const AttrPtr nil;
@ -86,8 +82,7 @@ public:
*/ */
std::string DeprecationMessage() const; std::string DeprecationMessage() const;
bool operator==(const Attr& other) const bool operator==(const Attr& other) const {
{
if ( tag != other.tag ) if ( tag != other.tag )
return false; return false;
@ -111,8 +106,7 @@ protected:
}; };
// Manages a collection of attributes. // Manages a collection of attributes.
class Attributes final : public Obj class Attributes final : public Obj {
{
public: public:
Attributes(std::vector<AttrPtr> a, TypePtr t, bool in_record, bool is_global); Attributes(std::vector<AttrPtr> a, TypePtr t, bool in_record, bool is_global);
Attributes(TypePtr t, bool in_record, bool is_global); Attributes(TypePtr t, bool in_record, bool is_global);
@ -154,8 +148,7 @@ protected:
// Returns true on compatibility (which might include modifying "a"), false // Returns true on compatibility (which might include modifying "a"), false
// on an error. If an error message hasn't been directly generated, then // on an error. If an error message hasn't been directly generated, then
// it will be returned in err_msg. // it will be returned in err_msg.
extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, std::string& err_msg);
std::string& err_msg);
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -8,15 +8,13 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
int Base64Converter::default_base64_table[256]; int Base64Converter::default_base64_table[256];
const std::string Base64Converter::default_alphabet = const std::string Base64Converter::default_alphabet =
"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) {
{
int blen; int blen;
char* buf; char* buf;
@ -26,20 +24,17 @@ void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, cha
if ( *pbuf && (*pblen % 4 != 0) ) if ( *pbuf && (*pblen % 4 != 0) )
reporter->InternalError("Base64 encode buffer not a multiple of 4"); reporter->InternalError("Base64 encode buffer not a multiple of 4");
if ( *pbuf ) if ( *pbuf ) {
{
buf = *pbuf; buf = *pbuf;
blen = *pblen; blen = *pblen;
} }
else else {
{
blen = (int)(4 * ceil((double)len / 3)); blen = (int)(4 * ceil((double)len / 3));
*pbuf = buf = new char[blen]; *pbuf = buf = new char[blen];
*pblen = blen; *pblen = blen;
} }
for ( int i = 0, j = 0; (i < len) && (j < blen); ) for ( int i = 0, j = 0; (i < len) && (j < blen); ) {
{
uint32_t bit32 = data[i++] << 16; uint32_t bit32 = data[i++] << 16;
bit32 += (i++ < len ? data[i - 1] : 0) << 8; bit32 += (i++ < len ? data[i - 1] : 0) << 8;
bit32 += i++ < len ? data[i - 1] : 0; bit32 += i++ < len ? data[i - 1] : 0;
@ -51,8 +46,7 @@ void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, cha
} }
} }
int* Base64Converter::InitBase64Table(const std::string& alphabet) int* Base64Converter::InitBase64Table(const std::string& alphabet) {
{
assert(alphabet.size() == 64); assert(alphabet.size() == 64);
static bool default_table_initialized = false; static bool default_table_initialized = false;
@ -62,8 +56,7 @@ int* Base64Converter::InitBase64Table(const std::string& alphabet)
int* base64_table = nullptr; int* base64_table = nullptr;
if ( alphabet == default_alphabet ) if ( alphabet == default_alphabet ) {
{
base64_table = default_base64_table; base64_table = default_base64_table;
default_table_initialized = true; default_table_initialized = true;
} }
@ -74,8 +67,7 @@ int* Base64Converter::InitBase64Table(const std::string& alphabet)
for ( i = 0; i < 256; ++i ) for ( i = 0; i < 256; ++i )
base64_table[i] = -1; base64_table[i] = -1;
for ( i = 0; i < 26; ++i ) for ( i = 0; i < 26; ++i ) {
{
base64_table[int(alphabet[0 + i])] = i; base64_table[int(alphabet[0 + i])] = i;
base64_table[int(alphabet[26 + i])] = i + 26; base64_table[int(alphabet[26 + i])] = i + 26;
} }
@ -91,15 +83,12 @@ int* Base64Converter::InitBase64Table(const std::string& alphabet)
return base64_table; return base64_table;
} }
Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) {
{ if ( arg_alphabet.size() > 0 ) {
if ( arg_alphabet.size() > 0 )
{
assert(arg_alphabet.size() == 64); assert(arg_alphabet.size() == 64);
alphabet = arg_alphabet; alphabet = arg_alphabet;
} }
else else {
{
alphabet = default_alphabet; alphabet = default_alphabet;
} }
@ -110,14 +99,12 @@ Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_al
conn = arg_conn; conn = arg_conn;
} }
Base64Converter::~Base64Converter() Base64Converter::~Base64Converter() {
{
if ( base64_table != default_base64_table ) if ( base64_table != default_base64_table )
delete[] base64_table; delete[] base64_table;
} }
int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) {
{
int blen; int blen;
char* buf; char* buf;
@ -128,13 +115,11 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
if ( ! pbuf ) if ( ! pbuf )
reporter->InternalError("nil pointer to decoding result buffer"); reporter->InternalError("nil pointer to decoding result buffer");
if ( *pbuf ) if ( *pbuf ) {
{
buf = *pbuf; buf = *pbuf;
blen = *pblen; blen = *pblen;
} }
else else {
{
// Estimate the maximal number of 3-byte groups needed, // Estimate the maximal number of 3-byte groups needed,
// plus 1 byte for the optional ending NUL. // plus 1 byte for the optional ending NUL.
blen = int((len + base64_group_next + 3) / 4) * 3 + 1; blen = int((len + base64_group_next + 3) / 4) * 3 + 1;
@ -143,14 +128,11 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
int dlen = 0; int dlen = 0;
while ( true ) while ( true ) {
{ if ( base64_group_next == 4 ) {
if ( base64_group_next == 4 )
{
// For every group of 4 6-bit numbers, // For every group of 4 6-bit numbers,
// write the decoded 3 bytes to the buffer. // write the decoded 3 bytes to the buffer.
if ( base64_after_padding ) if ( base64_after_padding ) {
{
if ( ++errored == 1 ) if ( ++errored == 1 )
IllegalEncoding("extra base64 groups after '=' padding are ignored"); IllegalEncoding("extra base64 groups after '=' padding are ignored");
base64_group_next = 0; base64_group_next = 0;
@ -191,8 +173,7 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
int k = base64_table[c]; int k = base64_table[c];
if ( k >= 0 ) if ( k >= 0 )
base64_group[base64_group_next++] = k; base64_group[base64_group_next++] = k;
else else {
{
if ( ++errored == 1 ) if ( ++errored == 1 )
IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c)); IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c));
} }
@ -204,15 +185,13 @@ int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf)
return dlen; return dlen;
} }
int Base64Converter::Done(int* pblen, char** pbuf) int Base64Converter::Done(int* pblen, char** pbuf) {
{
const char* padding = "==="; const char* padding = "===";
if ( base64_group_next != 0 ) if ( base64_group_next != 0 ) {
{
if ( base64_group_next < 4 ) if ( base64_group_next < 4 )
IllegalEncoding(util::fmt("incomplete base64 group, padding with %d bits of 0", IllegalEncoding(
(4 - base64_group_next) * 6)); util::fmt("incomplete base64 group, padding with %d bits of 0", (4 - base64_group_next) * 6));
Decode(4 - base64_group_next, padding, pblen, pbuf); Decode(4 - base64_group_next, padding, pblen, pbuf);
return -1; return -1;
} }
@ -223,8 +202,7 @@ int Base64Converter::Done(int* pblen, char** pbuf)
return 0; return 0;
} }
void Base64Converter::IllegalEncoding(const char* msg) void Base64Converter::IllegalEncoding(const char* msg) {
{
// strncpy(error_msg, msg, sizeof(error_msg)); // strncpy(error_msg, msg, sizeof(error_msg));
if ( conn ) if ( conn )
conn->Weird("base64_illegal_encoding", msg); conn->Weird("base64_illegal_encoding", msg);
@ -232,10 +210,8 @@ void Base64Converter::IllegalEncoding(const char* msg)
reporter->Error("%s", msg); reporter->Error("%s", msg);
} }
String* decode_base64(const String* s, const String* a, Connection* conn) String* decode_base64(const String* s, const String* a, Connection* conn) {
{ if ( a && a->Len() != 0 && a->Len() != 64 ) {
if ( a && a->Len() != 0 && a->Len() != 64 )
{
reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString()); reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString());
return nullptr; return nullptr;
} }
@ -265,10 +241,8 @@ err:
return nullptr; return nullptr;
} }
String* encode_base64(const String* s, const String* a, Connection* conn) String* encode_base64(const String* s, const String* a, Connection* conn) {
{ if ( a && a->Len() != 0 && a->Len() != 64 ) {
if ( a && a->Len() != 0 && a->Len() != 64 )
{
reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString()); reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString());
return nullptr; return nullptr;
} }

View file

@ -4,18 +4,15 @@
#include <string> #include <string>
namespace zeek namespace zeek {
{
class String; class String;
class Connection; class Connection;
namespace detail namespace detail {
{
// Maybe we should have a base class for generic decoders? // Maybe we should have a base class for generic decoders?
class Base64Converter class Base64Converter {
{
public: public:
// <conn> is used for error reporting. If it is set to zero (as, // <conn> is used for error reporting. If it is set to zero (as,
// e.g., done by the built-in functions decode_base64() and // e.g., done by the built-in functions decode_base64() and

View file

@ -4,8 +4,7 @@
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
BifReturnVal::BifReturnVal(std::nullptr_t) noexcept {} BifReturnVal::BifReturnVal(std::nullptr_t) noexcept {}

View file

@ -6,26 +6,22 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class Val; class Val;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
namespace detail namespace detail {
{
/** /**
* A simple wrapper class to use for the return value of BIFs so that * A simple wrapper class to use for the return value of BIFs so that
* they may return either a Val* or IntrusivePtr<Val> (the former could * they may return either a Val* or IntrusivePtr<Val> (the former could
* potentially be deprecated). * potentially be deprecated).
*/ */
class BifReturnVal class BifReturnVal {
{
public: public:
template <typename T> BifReturnVal(IntrusivePtr<T> v) noexcept : rval(AdoptRef{}, v.release()) template<typename T>
{ BifReturnVal(IntrusivePtr<T> v) noexcept : rval(AdoptRef{}, v.release()) {}
}
BifReturnVal(std::nullptr_t) noexcept; BifReturnVal(std::nullptr_t) noexcept;

View file

@ -9,30 +9,23 @@
#include "zeek/DFA.h" #include "zeek/DFA.h"
#include "zeek/RE.h" #include "zeek/RE.h"
namespace zeek::detail namespace zeek::detail {
{
CCL::CCL() CCL::CCL() {
{
syms = new int_list; syms = new int_list;
index = -(rem->InsertCCL(this) + 1); index = -(rem->InsertCCL(this) + 1);
negated = 0; negated = 0;
} }
CCL::~CCL() CCL::~CCL() { delete syms; }
{
delete syms;
}
void CCL::Negate() void CCL::Negate() {
{
negated = 1; negated = 1;
Add(SYM_BOL); Add(SYM_BOL);
Add(SYM_EOL); Add(SYM_EOL);
} }
void CCL::Add(int sym) void CCL::Add(int sym) {
{
auto sym_p = static_cast<std::intptr_t>(sym); auto sym_p = static_cast<std::intptr_t>(sym);
// Check to see if the character is already in the ccl. // Check to see if the character is already in the ccl.
@ -43,9 +36,6 @@ void CCL::Add(int sym)
syms->push_back(sym_p); syms->push_back(sym_p);
} }
void CCL::Sort() void CCL::Sort() { std::sort(syms->begin(), syms->end()); }
{
std::sort(syms->begin(), syms->end());
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -5,13 +5,11 @@
#include <cstdint> #include <cstdint>
#include <vector> #include <vector>
namespace zeek::detail namespace zeek::detail {
{
using int_list = std::vector<std::intptr_t>; using int_list = std::vector<std::intptr_t>;
class CCL class CCL {
{
public: public:
CCL(); CCL();
~CCL(); ~CCL();
@ -25,8 +23,7 @@ public:
int_list* Syms() { return syms; } int_list* Syms() { return syms; }
void ReplaceSyms(int_list* new_syms) void ReplaceSyms(int_list* new_syms) {
{
delete syms; delete syms;
syms = new_syms; syms = new_syms;
} }

View file

@ -15,15 +15,12 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
// A comparison callable to assist with consistent iteration order over tables // A comparison callable to assist with consistent iteration order over tables
// during reservation & writes. // during reservation & writes.
struct HashKeyComparer struct HashKeyComparer {
{ bool operator()(const std::unique_ptr<HashKey>& a, const std::unique_ptr<HashKey>& b) const {
bool operator()(const std::unique_ptr<HashKey>& a, const std::unique_ptr<HashKey>& b) const
{
if ( a->Hash() != b->Hash() ) if ( a->Hash() != b->Hash() )
return a->Hash() < b->Hash(); return a->Hash() < b->Hash();
if ( a->Size() != b->Size() ) if ( a->Size() != b->Size() )
@ -37,14 +34,12 @@ using HashkeyMapPtr = std::unique_ptr<HashkeyMap>;
// Helper that produces a table from HashKeys to the ListVal indexes into the // Helper that produces a table from HashKeys to the ListVal indexes into the
// table, that we can iterate over in sorted-Hashkey order. // table, that we can iterate over in sorted-Hashkey order.
const HashkeyMapPtr ordered_hashkeys(const TableVal* tv) const HashkeyMapPtr ordered_hashkeys(const TableVal* tv) {
{
auto res = std::make_unique<HashkeyMap>(); auto res = std::make_unique<HashkeyMap>();
auto tbl = tv->AsTable(); auto tbl = tv->AsTable();
auto idx = 0; auto idx = 0;
for ( const auto& entry : *tbl ) for ( const auto& entry : *tbl ) {
{
auto k = entry.GetHashKey(); auto k = entry.GetHashKey();
// Potential optimization: we could do without the following if // Potential optimization: we could do without the following if
// the caller uses k directly to determine key length & // the caller uses k directly to determine key length &
@ -60,25 +55,21 @@ const HashkeyMapPtr ordered_hashkeys(const TableVal* tv)
return res; return res;
} }
CompositeHash::CompositeHash(TypeListPtr composite_type) : type(std::move(composite_type)) CompositeHash::CompositeHash(TypeListPtr composite_type) : type(std::move(composite_type)) {
{
if ( type->GetTypes().size() == 1 ) if ( type->GetTypes().size() == 1 )
is_singleton = true; is_singleton = true;
} }
std::unique_ptr<HashKey> CompositeHash::MakeHashKey(const Val& argv, bool type_check) const std::unique_ptr<HashKey> CompositeHash::MakeHashKey(const Val& argv, bool type_check) const {
{
auto res = std::make_unique<HashKey>(); auto res = std::make_unique<HashKey>();
const auto& tl = type->GetTypes(); const auto& tl = type->GetTypes();
if ( is_singleton ) if ( is_singleton ) {
{
const Val* v = &argv; const Val* v = &argv;
// This is the "singleton" case -- actually just a single value // This is the "singleton" case -- actually just a single value
// that may come bundled in a list. If so, unwrap it. // that may come bundled in a list. If so, unwrap it.
if ( v->GetType()->Tag() == TYPE_LIST ) if ( v->GetType()->Tag() == TYPE_LIST ) {
{
auto lv = v->AsListVal(); auto lv = v->AsListVal();
if ( type_check && lv->Length() != 1 ) if ( type_check && lv->Length() != 1 )
@ -105,25 +96,21 @@ std::unique_ptr<HashKey> CompositeHash::MakeHashKey(const Val& argv, bool type_c
// The size computation resulted in a requested buffer size; allocate it. // The size computation resulted in a requested buffer size; allocate it.
res->Allocate(); res->Allocate();
for ( auto i = 0u; i < tl.size(); ++i ) for ( auto i = 0u; i < tl.size(); ++i ) {
{ if ( ! SingleValHash(*res, argv.AsListVal()->Idx(i).get(), tl[i].get(), type_check, false, false) )
if ( ! SingleValHash(*res, argv.AsListVal()->Idx(i).get(), tl[i].get(), type_check, false,
false) )
return nullptr; return nullptr;
} }
return res; return res;
} }
ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const {
{
auto l = make_intrusive<ListVal>(TYPE_ANY); auto l = make_intrusive<ListVal>(TYPE_ANY);
const auto& tl = type->GetTypes(); const auto& tl = type->GetTypes();
hk.ResetRead(); hk.ResetRead();
for ( const auto& type : tl ) for ( const auto& type : tl ) {
{
ValPtr v; ValPtr v;
if ( ! RecoverOneVal(hk, type.get(), &v, false, is_singleton) ) if ( ! RecoverOneVal(hk, type.get(), &v, false, is_singleton) )
@ -136,28 +123,22 @@ ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const
return l; return l;
} }
bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool optional, bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool optional, bool singleton) const {
bool singleton) const
{
TypeTag tag = t->Tag(); TypeTag tag = t->Tag();
InternalTypeTag it = t->InternalType(); InternalTypeTag it = t->InternalType();
if ( optional ) if ( optional ) {
{
bool opt; bool opt;
hk.Read("optional", opt); hk.Read("optional", opt);
if ( ! opt ) if ( ! opt ) {
{
*pval = nullptr; *pval = nullptr;
return true; return true;
} }
} }
switch ( it ) switch ( it ) {
{ case TYPE_INTERNAL_INT: {
case TYPE_INTERNAL_INT:
{
zeek_int_t i; zeek_int_t i;
hk.Read("int", i); hk.Read("int", i);
@ -167,42 +148,30 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
*pval = val_mgr->Bool(i); *pval = val_mgr->Bool(i);
else if ( tag == TYPE_INT ) else if ( tag == TYPE_INT )
*pval = val_mgr->Int(i); *pval = val_mgr->Int(i);
else else {
{ reporter->InternalError("bad internal unsigned int in CompositeHash::RecoverOneVal()");
reporter->InternalError(
"bad internal unsigned int in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_INTERNAL_UNSIGNED: case TYPE_INTERNAL_UNSIGNED: {
{
zeek_uint_t u; zeek_uint_t u;
hk.Read("unsigned", u); hk.Read("unsigned", u);
switch ( tag ) switch ( tag ) {
{ case TYPE_COUNT: *pval = val_mgr->Count(u); break;
case TYPE_COUNT:
*pval = val_mgr->Count(u);
break;
case TYPE_PORT: case TYPE_PORT: *pval = val_mgr->Port(u); break;
*pval = val_mgr->Port(u);
break;
default: default:
reporter->InternalError( reporter->InternalError("bad internal unsigned int in CompositeHash::RecoverOneVal()");
"bad internal unsigned int in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_INTERNAL_DOUBLE: case TYPE_INTERNAL_DOUBLE: {
{
double d; double d;
hk.Read("double", d); hk.Read("double", d);
@ -212,33 +181,25 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
*pval = make_intrusive<TimeVal>(d); *pval = make_intrusive<TimeVal>(d);
else else
*pval = make_intrusive<DoubleVal>(d); *pval = make_intrusive<DoubleVal>(d);
} } break;
break;
case TYPE_INTERNAL_ADDR: case TYPE_INTERNAL_ADDR: {
{
hk.AlignRead(sizeof(uint32_t)); hk.AlignRead(sizeof(uint32_t));
hk.EnsureReadSpace(sizeof(uint32_t) * 4); hk.EnsureReadSpace(sizeof(uint32_t) * 4);
IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network); IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network);
hk.SkipRead("addr", sizeof(uint32_t) * 4); hk.SkipRead("addr", sizeof(uint32_t) * 4);
switch ( tag ) switch ( tag ) {
{ case TYPE_ADDR: *pval = make_intrusive<AddrVal>(addr); break;
case TYPE_ADDR:
*pval = make_intrusive<AddrVal>(addr);
break;
default: default:
reporter->InternalError( reporter->InternalError("bad internal address in CompositeHash::RecoverOneVal()");
"bad internal address in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_INTERNAL_SUBNET: case TYPE_INTERNAL_SUBNET: {
{
hk.AlignRead(sizeof(uint32_t)); hk.AlignRead(sizeof(uint32_t));
hk.EnsureReadSpace(sizeof(uint32_t) * 4); hk.EnsureReadSpace(sizeof(uint32_t) * 4);
IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network); IPAddr addr(IPv6, static_cast<const uint32_t*>(hk.KeyAtRead()), IPAddr::Network);
@ -247,16 +208,12 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
uint32_t width; uint32_t width;
hk.Read("subnet-width", width); hk.Read("subnet-width", width);
*pval = make_intrusive<SubNetVal>(addr, width); *pval = make_intrusive<SubNetVal>(addr, width);
} } break;
break;
case TYPE_INTERNAL_VOID: case TYPE_INTERNAL_VOID:
case TYPE_INTERNAL_OTHER: case TYPE_INTERNAL_OTHER: {
{ switch ( t->Tag() ) {
switch ( t->Tag() ) case TYPE_FUNC: {
{
case TYPE_FUNC:
{
uint32_t id; uint32_t id;
hk.Read("func", id); hk.Read("func", id);
@ -273,36 +230,29 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
const auto& pvt = (*pval)->GetType(); const auto& pvt = (*pval)->GetType();
if ( ! pvt ) if ( ! pvt )
reporter->InternalError( reporter->InternalError("bad aggregate Val in CompositeHash::RecoverOneVal()");
"bad aggregate Val in CompositeHash::RecoverOneVal()");
else if ( t->Tag() != TYPE_FUNC && ! same_type(pvt, t) ) else if ( t->Tag() != TYPE_FUNC && ! same_type(pvt, t) )
// ### Maybe fix later, but may be fundamentally un-checkable --US // ### Maybe fix later, but may be fundamentally un-checkable --US
{ {
reporter->InternalError( reporter->InternalError("inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
"inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
// ### A crude approximation for now. // ### A crude approximation for now.
else if ( t->Tag() == TYPE_FUNC && pvt->Tag() != TYPE_FUNC ) else if ( t->Tag() == TYPE_FUNC && pvt->Tag() != TYPE_FUNC ) {
{ reporter->InternalError("inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
reporter->InternalError(
"inconsistent aggregate Val in CompositeHash::RecoverOneVal()");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} } break;
break;
case TYPE_PATTERN: case TYPE_PATTERN: {
{
const char* texts[2] = {nullptr, nullptr}; const char* texts[2] = {nullptr, nullptr};
uint64_t lens[2] = {0, 0}; uint64_t lens[2] = {0, 0};
if ( ! singleton ) if ( ! singleton ) {
{
hk.Read("pattern-len1", lens[0]); hk.Read("pattern-len1", lens[0]);
hk.Read("pattern-len2", lens[1]); hk.Read("pattern-len2", lens[1]);
} }
@ -315,29 +265,23 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
RE_Matcher* re = new RE_Matcher(texts[0], texts[1]); RE_Matcher* re = new RE_Matcher(texts[0], texts[1]);
if ( ! re->Compile() ) if ( ! re->Compile() )
reporter->InternalError("failed compiling table/set key pattern: %s", reporter->InternalError("failed compiling table/set key pattern: %s", re->PatternText());
re->PatternText());
*pval = make_intrusive<PatternVal>(re); *pval = make_intrusive<PatternVal>(re);
} } break;
break;
case TYPE_RECORD: case TYPE_RECORD: {
{
auto rt = t->AsRecordType(); auto rt = t->AsRecordType();
int num_fields = rt->NumFields(); int num_fields = rt->NumFields();
std::vector<ValPtr> values; std::vector<ValPtr> values;
int i; int i;
for ( i = 0; i < num_fields; ++i ) for ( i = 0; i < num_fields; ++i ) {
{
ValPtr v; ValPtr v;
Attributes* a = rt->FieldDecl(i)->attrs.get(); Attributes* a = rt->FieldDecl(i)->attrs.get();
bool is_optional = (a && a->Find(ATTR_OPTIONAL)); bool is_optional = (a && a->Find(ATTR_OPTIONAL));
if ( ! RecoverOneVal(hk, rt->GetFieldType(i).get(), &v, is_optional, if ( ! RecoverOneVal(hk, rt->GetFieldType(i).get(), &v, is_optional, false) ) {
false) )
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -346,10 +290,8 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
// abort() and broken the call tree that clang-tidy is relying on to // abort() and broken the call tree that clang-tidy is relying on to
// get the error described. // get the error described.
// NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Branch) // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Branch)
if ( ! (v || is_optional) ) if ( ! (v || is_optional) ) {
{ reporter->InternalError("didn't recover expected number of fields from HashKey");
reporter->InternalError(
"didn't recover expected number of fields from HashKey");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -359,39 +301,32 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
ASSERT(int(values.size()) == num_fields); ASSERT(int(values.size()) == num_fields);
auto rv = make_intrusive<RecordVal>(IntrusivePtr{NewRef{}, rt}, auto rv = make_intrusive<RecordVal>(IntrusivePtr{NewRef{}, rt}, false /* init_fields */);
false /* init_fields */);
for ( int i = 0; i < num_fields; ++i ) for ( int i = 0; i < num_fields; ++i )
rv->AppendField(std::move(values[i]), rt->GetFieldType(i)); rv->AppendField(std::move(values[i]), rt->GetFieldType(i));
*pval = std::move(rv); *pval = std::move(rv);
} } break;
break;
case TYPE_TABLE: case TYPE_TABLE: {
{
int n; int n;
hk.Read("table-size", n); hk.Read("table-size", n);
auto tt = t->AsTableType(); auto tt = t->AsTableType();
auto tv = make_intrusive<TableVal>(IntrusivePtr{NewRef{}, tt}); auto tv = make_intrusive<TableVal>(IntrusivePtr{NewRef{}, tt});
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
ValPtr key; ValPtr key;
if ( ! RecoverOneVal(hk, tt->GetIndices().get(), &key, false, false) ) if ( ! RecoverOneVal(hk, tt->GetIndices().get(), &key, false, false) ) {
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
if ( t->IsSet() ) if ( t->IsSet() )
tv->Assign(std::move(key), nullptr); tv->Assign(std::move(key), nullptr);
else else {
{
ValPtr value; ValPtr value;
if ( ! RecoverOneVal(hk, tt->Yield().get(), &value, false, false) ) if ( ! RecoverOneVal(hk, tt->Yield().get(), &value, false, false) ) {
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -400,27 +335,22 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
} }
*pval = std::move(tv); *pval = std::move(tv);
} } break;
break;
case TYPE_VECTOR: case TYPE_VECTOR: {
{
unsigned int n; unsigned int n;
hk.Read("vector-size", n); hk.Read("vector-size", n);
auto vt = t->AsVectorType(); auto vt = t->AsVectorType();
auto vv = make_intrusive<VectorVal>(IntrusivePtr{NewRef{}, vt}); auto vv = make_intrusive<VectorVal>(IntrusivePtr{NewRef{}, vt});
for ( unsigned int i = 0; i < n; ++i ) for ( unsigned int i = 0; i < n; ++i ) {
{
unsigned int index; unsigned int index;
hk.Read("vector-idx", index); hk.Read("vector-idx", index);
bool have_val; bool have_val;
hk.Read("vector-idx-present", have_val); hk.Read("vector-idx-present", have_val);
ValPtr value; ValPtr value;
if ( have_val && if ( have_val && ! RecoverOneVal(hk, vt->Yield().get(), &value, false, false) ) {
! RecoverOneVal(hk, vt->Yield().get(), &value, false, false) )
{
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
@ -429,18 +359,15 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
} }
*pval = std::move(vv); *pval = std::move(vv);
} } break;
break;
case TYPE_LIST: case TYPE_LIST: {
{
int n; int n;
hk.Read("list-size", n); hk.Read("list-size", n);
auto tl = t->AsTypeList(); auto tl = t->AsTypeList();
auto lv = make_intrusive<ListVal>(TYPE_ANY); auto lv = make_intrusive<ListVal>(TYPE_ANY);
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
ValPtr v; ValPtr v;
Type* it = tl->GetTypes()[i].get(); Type* it = tl->GetTypes()[i].get();
if ( ! RecoverOneVal(hk, it, &v, false, false) ) if ( ! RecoverOneVal(hk, it, &v, false, false) )
@ -449,55 +376,45 @@ bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool
} }
*pval = std::move(lv); *pval = std::move(lv);
} } break;
break;
default: default: {
{
reporter->InternalError("bad index type in CompositeHash::RecoverOneVal"); reporter->InternalError("bad index type in CompositeHash::RecoverOneVal");
*pval = nullptr; *pval = nullptr;
return false; return false;
} }
} }
} } break;
break;
case TYPE_INTERNAL_STRING: case TYPE_INTERNAL_STRING: {
{
int n = hk.Size(); int n = hk.Size();
if ( ! singleton ) if ( ! singleton ) {
{
hk.Read("string-len", n); hk.Read("string-len", n);
hk.EnsureReadSpace(n); hk.EnsureReadSpace(n);
} }
*pval = make_intrusive<StringVal>(new String((const byte_vec)hk.KeyAtRead(), n, true)); *pval = make_intrusive<StringVal>(new String((const byte_vec)hk.KeyAtRead(), n, true));
hk.SkipRead("string", n); hk.SkipRead("string", n);
} } break;
break;
case TYPE_INTERNAL_ERROR: case TYPE_INTERNAL_ERROR: break;
break;
} }
return true; return true;
} }
bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional,
bool optional, bool singleton) const bool singleton) const {
{
InternalTypeTag t = bt->InternalType(); InternalTypeTag t = bt->InternalType();
if ( type_check && v ) if ( type_check && v ) {
{
InternalTypeTag vt = v->GetType()->InternalType(); InternalTypeTag vt = v->GetType()->InternalType();
if ( vt != t ) if ( vt != t )
return false; return false;
} }
if ( optional ) if ( optional ) {
{
// Add a marker saying whether the optional field is set. // Add a marker saying whether the optional field is set.
hk.Write("optional", v != nullptr); hk.Write("optional", v != nullptr);
@ -510,15 +427,10 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
if ( ! v ) if ( ! v )
return false; return false;
switch ( t ) switch ( t ) {
{ case TYPE_INTERNAL_INT: hk.Write("int", v->AsInt()); break;
case TYPE_INTERNAL_INT:
hk.Write("int", v->AsInt());
break;
case TYPE_INTERNAL_UNSIGNED: case TYPE_INTERNAL_UNSIGNED: hk.Write("unsigned", v->AsCount()); break;
hk.Write("unsigned", v->AsCount());
break;
case TYPE_INTERNAL_ADDR: case TYPE_INTERNAL_ADDR:
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
@ -541,17 +453,12 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("subnet-width", v->AsSubNet().Length()); hk.Write("subnet-width", v->AsSubNet().Length());
break; break;
case TYPE_INTERNAL_DOUBLE: case TYPE_INTERNAL_DOUBLE: hk.Write("double", v->InternalDouble()); break;
hk.Write("double", v->InternalDouble());
break;
case TYPE_INTERNAL_VOID: case TYPE_INTERNAL_VOID:
case TYPE_INTERNAL_OTHER: case TYPE_INTERNAL_OTHER: {
{ switch ( v->GetType()->Tag() ) {
switch ( v->GetType()->Tag() ) case TYPE_FUNC: {
{
case TYPE_FUNC:
{
auto f = v->AsFunc(); auto f = v->AsFunc();
if ( ! func_to_func_id ) if ( ! func_to_func_id )
@ -560,8 +467,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
auto id_mapping = func_to_func_id->find(f); auto id_mapping = func_to_func_id->find(f);
uint32_t id; uint32_t id;
if ( id_mapping == func_to_func_id->end() ) if ( id_mapping == func_to_func_id->end() ) {
{
// We need the pointer to stick around // We need the pointer to stick around
// for our lifetime, so we have to get // for our lifetime, so we have to get
// a non-const version we can ref. // a non-const version we can ref.
@ -575,22 +481,17 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
id = id_mapping->second; id = id_mapping->second;
hk.Write("func", id); hk.Write("func", id);
} } break;
break;
case TYPE_PATTERN: case TYPE_PATTERN: {
{ const char* texts[2] = {v->AsPattern()->PatternText(), v->AsPattern()->AnywherePatternText()};
const char* texts[2] = {v->AsPattern()->PatternText(),
v->AsPattern()->AnywherePatternText()};
uint64_t lens[2] = {strlen(texts[0]) + 1, strlen(texts[1]) + 1}; uint64_t lens[2] = {strlen(texts[0]) + 1, strlen(texts[1]) + 1};
if ( ! singleton ) if ( ! singleton ) {
{
hk.Write("pattern-len1", lens[0]); hk.Write("pattern-len1", lens[0]);
hk.Write("pattern-len2", lens[1]); hk.Write("pattern-len2", lens[1]);
} }
else else {
{
hk.Reserve("pattern", lens[0] + lens[1]); hk.Reserve("pattern", lens[0] + lens[1]);
hk.Allocate(); hk.Allocate();
} }
@ -600,8 +501,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
break; break;
} }
case TYPE_RECORD: case TYPE_RECORD: {
{
auto rv = v->AsRecordVal(); auto rv = v->AsRecordVal();
auto rt = bt->AsRecordType(); auto rt = bt->AsRecordType();
int num_fields = rt->NumFields(); int num_fields = rt->NumFields();
@ -609,8 +509,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
for ( int i = 0; i < num_fields; ++i ) for ( int i = 0; i < num_fields; ++i ) {
{
auto rv_i = rv->GetField(i); auto rv_i = rv->GetField(i);
Attributes* a = rt->FieldDecl(i)->attrs.get(); Attributes* a = rt->FieldDecl(i)->attrs.get();
@ -619,15 +518,14 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
if ( ! (rv_i || optional_attr) ) if ( ! (rv_i || optional_attr) )
return false; return false;
if ( ! SingleValHash(hk, rv_i.get(), rt->GetFieldType(i).get(), type_check, if ( ! SingleValHash(hk, rv_i.get(), rt->GetFieldType(i).get(), type_check, optional_attr,
optional_attr, false) ) false) )
return false; return false;
} }
break; break;
} }
case TYPE_TABLE: case TYPE_TABLE: {
{
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
@ -636,28 +534,22 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("table-size", tv->Size()); hk.Write("table-size", tv->Size());
for ( auto& kv : *hashkeys ) for ( auto& kv : *hashkeys ) {
{
auto key = kv.second; auto key = kv.second;
if ( ! SingleValHash(hk, key.get(), key->GetType().get(), type_check, false, if ( ! SingleValHash(hk, key.get(), key->GetType().get(), type_check, false, false) )
false) )
return false; return false;
if ( ! v->GetType()->IsSet() ) if ( ! v->GetType()->IsSet() ) {
{
auto val = const_cast<TableVal*>(tv)->FindOrDefault(key); auto val = const_cast<TableVal*>(tv)->FindOrDefault(key);
if ( ! SingleValHash(hk, val.get(), val->GetType().get(), type_check, if ( ! SingleValHash(hk, val.get(), val->GetType().get(), type_check, false, false) )
false, false) )
return false; return false;
} }
} }
} } break;
break;
case TYPE_VECTOR: case TYPE_VECTOR: {
{
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
@ -666,25 +558,19 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("vector-size", vv->Size()); hk.Write("vector-size", vv->Size());
for ( unsigned int i = 0; i < vv->Size(); ++i ) for ( unsigned int i = 0; i < vv->Size(); ++i ) {
{
auto val = vv->ValAt(i); auto val = vv->ValAt(i);
hk.Write("vector-idx", i); hk.Write("vector-idx", i);
hk.Write("vector-idx-present", val != nullptr); hk.Write("vector-idx-present", val != nullptr);
if ( val && ! SingleValHash(hk, val.get(), vt->Yield().get(), type_check, if ( val && ! SingleValHash(hk, val.get(), vt->Yield().get(), type_check, false, false) )
false, false) )
return false; return false;
} }
} } break;
break;
case TYPE_LIST: case TYPE_LIST: {
{ if ( ! hk.IsAllocated() ) {
if ( ! hk.IsAllocated() ) if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false, false) )
{
if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false,
false) )
return false; return false;
hk.Allocate(); hk.Allocate();
@ -694,18 +580,14 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("list-size", lv->Length()); hk.Write("list-size", lv->Length());
for ( int i = 0; i < lv->Length(); ++i ) for ( int i = 0; i < lv->Length(); ++i ) {
{
Val* entry_val = lv->Idx(i).get(); Val* entry_val = lv->Idx(i).get();
if ( ! SingleValHash(hk, entry_val, entry_val->GetType().get(), type_check, if ( ! SingleValHash(hk, entry_val, entry_val->GetType().get(), type_check, false, false) )
false, false) )
return false; return false;
} }
} } break;
break;
default: default: {
{
reporter->InternalError("bad index type in CompositeHash::SingleValHash"); reporter->InternalError("bad index type in CompositeHash::SingleValHash");
return false; return false;
} }
@ -714,8 +596,7 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
break; // case TYPE_INTERNAL_VOID/OTHER break; // case TYPE_INTERNAL_VOID/OTHER
} }
case TYPE_INTERNAL_STRING: case TYPE_INTERNAL_STRING: {
{
if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) if ( ! EnsureTypeReserve(hk, v, bt, type_check) )
return false; return false;
@ -725,18 +606,15 @@ bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type
hk.Write("string-len", sval->Len()); hk.Write("string-len", sval->Len());
hk.Write("string", sval->Bytes(), sval->Len()); hk.Write("string", sval->Bytes(), sval->Len());
} } break;
break;
default: default: return false;
return false;
} }
return true; return true;
} }
bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const {
{
if ( hk.IsAllocated() ) if ( hk.IsAllocated() )
return true; return true;
@ -747,94 +625,69 @@ bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool
return true; return true;
} }
bool CompositeHash::ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool CompositeHash::ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool calc_static_size) const {
bool calc_static_size) const
{
const auto& tl = type->GetTypes(); const auto& tl = type->GetTypes();
for ( auto i = 0u; i < tl.size(); ++i ) for ( auto i = 0u; i < tl.size(); ++i ) {
{ if ( ! ReserveSingleTypeKeySize(hk, tl[i].get(), v ? v->AsListVal()->Idx(i).get() : nullptr, type_check, false,
if ( ! ReserveSingleTypeKeySize(hk, tl[i].get(), v ? v->AsListVal()->Idx(i).get() : nullptr, calc_static_size, is_singleton) )
type_check, false, calc_static_size, is_singleton) )
return false; return false;
} }
return true; return true;
} }
bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v, bool type_check, bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v, bool type_check, bool optional,
bool optional, bool calc_static_size, bool calc_static_size, bool singleton) const {
bool singleton) const
{
InternalTypeTag t = bt->InternalType(); InternalTypeTag t = bt->InternalType();
if ( optional ) if ( optional ) {
{
hk.ReserveType<bool>("optional"); hk.ReserveType<bool>("optional");
if ( ! v ) if ( ! v )
return true; return true;
} }
if ( type_check && v ) if ( type_check && v ) {
{
InternalTypeTag vt = v->GetType()->InternalType(); InternalTypeTag vt = v->GetType()->InternalType();
if ( vt != t ) if ( vt != t )
return false; return false;
} }
switch ( t ) switch ( t ) {
{ case TYPE_INTERNAL_INT: hk.ReserveType<zeek_int_t>("int"); break;
case TYPE_INTERNAL_INT:
hk.ReserveType<zeek_int_t>("int");
break;
case TYPE_INTERNAL_UNSIGNED: case TYPE_INTERNAL_UNSIGNED: hk.ReserveType<zeek_int_t>("unsigned"); break;
hk.ReserveType<zeek_int_t>("unsigned");
break;
case TYPE_INTERNAL_ADDR: case TYPE_INTERNAL_ADDR: hk.Reserve("addr", sizeof(uint32_t) * 4, sizeof(uint32_t)); break;
hk.Reserve("addr", sizeof(uint32_t) * 4, sizeof(uint32_t));
break;
case TYPE_INTERNAL_SUBNET: case TYPE_INTERNAL_SUBNET: hk.Reserve("subnet", sizeof(uint32_t) * 5, sizeof(uint32_t)); break;
hk.Reserve("subnet", sizeof(uint32_t) * 5, sizeof(uint32_t));
break;
case TYPE_INTERNAL_DOUBLE: case TYPE_INTERNAL_DOUBLE: hk.ReserveType<double>("double"); break;
hk.ReserveType<double>("double");
break;
case TYPE_INTERNAL_VOID: case TYPE_INTERNAL_VOID:
case TYPE_INTERNAL_OTHER: case TYPE_INTERNAL_OTHER: {
{ switch ( bt->Tag() ) {
switch ( bt->Tag() ) case TYPE_FUNC: {
{
case TYPE_FUNC:
{
hk.ReserveType<uint32_t>("func"); hk.ReserveType<uint32_t>("func");
break; break;
} }
case TYPE_PATTERN: case TYPE_PATTERN: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
if ( ! singleton ) if ( ! singleton ) {
{
hk.ReserveType<uint64_t>("pattern-len1"); hk.ReserveType<uint64_t>("pattern-len1");
hk.ReserveType<uint64_t>("pattern-len2"); hk.ReserveType<uint64_t>("pattern-len2");
} }
// +1 in the following to include null terminators // +1 in the following to include null terminators
hk.Reserve("pattern-string1", strlen(v->AsPattern()->PatternText()) + 1, 0); hk.Reserve("pattern-string1", strlen(v->AsPattern()->PatternText()) + 1, 0);
hk.Reserve("pattern-string1", strlen(v->AsPattern()->AnywherePatternText()) + 1, hk.Reserve("pattern-string1", strlen(v->AsPattern()->AnywherePatternText()) + 1, 0);
0);
break; break;
} }
case TYPE_RECORD: case TYPE_RECORD: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
@ -842,22 +695,19 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
RecordType* rt = bt->AsRecordType(); RecordType* rt = bt->AsRecordType();
int num_fields = rt->NumFields(); int num_fields = rt->NumFields();
for ( int i = 0; i < num_fields; ++i ) for ( int i = 0; i < num_fields; ++i ) {
{
Attributes* a = rt->FieldDecl(i)->attrs.get(); Attributes* a = rt->FieldDecl(i)->attrs.get();
bool optional_attr = (a && a->Find(ATTR_OPTIONAL)); bool optional_attr = (a && a->Find(ATTR_OPTIONAL));
auto rv_v = rv ? rv->GetField(i) : nullptr; auto rv_v = rv ? rv->GetField(i) : nullptr;
if ( ! ReserveSingleTypeKeySize(hk, rt->GetFieldType(i).get(), rv_v.get(), if ( ! ReserveSingleTypeKeySize(hk, rt->GetFieldType(i).get(), rv_v.get(), type_check,
type_check, optional_attr, calc_static_size, optional_attr, calc_static_size, false) )
false) )
return false; return false;
} }
break; break;
} }
case TYPE_TABLE: case TYPE_TABLE: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
@ -866,21 +716,17 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
hk.ReserveType<int>("table-size"); hk.ReserveType<int>("table-size");
for ( auto& kv : *hashkeys ) for ( auto& kv : *hashkeys ) {
{
auto key = kv.second; auto key = kv.second;
if ( ! ReserveSingleTypeKeySize(hk, key->GetType().get(), key.get(), if ( ! ReserveSingleTypeKeySize(hk, key->GetType().get(), key.get(), type_check, false,
type_check, false, calc_static_size, calc_static_size, false) )
false) )
return false; return false;
if ( ! bt->IsSet() ) if ( ! bt->IsSet() ) {
{
auto val = const_cast<TableVal*>(tv)->FindOrDefault(key); auto val = const_cast<TableVal*>(tv)->FindOrDefault(key);
if ( ! ReserveSingleTypeKeySize(hk, val->GetType().get(), val.get(), if ( ! ReserveSingleTypeKeySize(hk, val->GetType().get(), val.get(), type_check, false,
type_check, false, calc_static_size, calc_static_size, false) )
false) )
return false; return false;
} }
} }
@ -888,48 +734,40 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
break; break;
} }
case TYPE_VECTOR: case TYPE_VECTOR: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
hk.ReserveType<int>("vector-size"); hk.ReserveType<int>("vector-size");
VectorVal* vv = const_cast<VectorVal*>(v->AsVectorVal()); VectorVal* vv = const_cast<VectorVal*>(v->AsVectorVal());
for ( unsigned int i = 0; i < vv->Size(); ++i ) for ( unsigned int i = 0; i < vv->Size(); ++i ) {
{
auto val = vv->ValAt(i); auto val = vv->ValAt(i);
hk.ReserveType<unsigned int>("vector-idx"); hk.ReserveType<unsigned int>("vector-idx");
hk.ReserveType<unsigned int>("vector-idx-present"); hk.ReserveType<unsigned int>("vector-idx-present");
if ( val && ! ReserveSingleTypeKeySize( if ( val && ! ReserveSingleTypeKeySize(hk, bt->AsVectorType()->Yield().get(), val.get(),
hk, bt->AsVectorType()->Yield().get(), val.get(),
type_check, false, calc_static_size, false) ) type_check, false, calc_static_size, false) )
return false; return false;
} }
break; break;
} }
case TYPE_LIST: case TYPE_LIST: {
{
if ( ! v ) if ( ! v )
return (optional && ! calc_static_size); return (optional && ! calc_static_size);
hk.ReserveType<int>("list-size"); hk.ReserveType<int>("list-size");
ListVal* lv = const_cast<ListVal*>(v->AsListVal()); ListVal* lv = const_cast<ListVal*>(v->AsListVal());
for ( int i = 0; i < lv->Length(); ++i ) for ( int i = 0; i < lv->Length(); ++i ) {
{ if ( ! ReserveSingleTypeKeySize(hk, lv->Idx(i)->GetType().get(), lv->Idx(i).get(), type_check,
if ( ! ReserveSingleTypeKeySize(hk, lv->Idx(i)->GetType().get(), false, calc_static_size, false) )
lv->Idx(i).get(), type_check, false,
calc_static_size, false) )
return false; return false;
} }
break; break;
} }
default: default: {
{ reporter->InternalError("bad index type in CompositeHash::ReserveSingleTypeKeySize");
reporter->InternalError(
"bad index type in CompositeHash::ReserveSingleTypeKeySize");
return false; return false;
} }
} }
@ -945,8 +783,7 @@ bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v
hk.Reserve("string", v->AsString()->Len()); hk.Reserve("string", v->AsString()->Len());
break; break;
case TYPE_INTERNAL_ERROR: case TYPE_INTERNAL_ERROR: return false;
return false;
} }
return true; return true;

View file

@ -7,21 +7,18 @@
#include "zeek/Func.h" #include "zeek/Func.h"
#include "zeek/Type.h" #include "zeek/Type.h"
namespace zeek namespace zeek {
{
class ListVal; class ListVal;
using ListValPtr = zeek::IntrusivePtr<ListVal>; using ListValPtr = zeek::IntrusivePtr<ListVal>;
} // namespace zeek } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class HashKey; class HashKey;
class CompositeHash class CompositeHash {
{
public: public:
explicit CompositeHash(TypeListPtr composite_type); explicit CompositeHash(TypeListPtr composite_type);
@ -33,15 +30,13 @@ public:
ListValPtr RecoverVals(const HashKey& k) const; ListValPtr RecoverVals(const HashKey& k) const;
protected: protected:
bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, bool singleton) const;
bool singleton) const;
// Recovers just one Val of possibly many; called from RecoverVals. // Recovers just one Val of possibly many; called from RecoverVals.
// Upon return, pval will point to the recovered Val of type t. // Upon return, pval will point to the recovered Val of type t.
// Returns and updated kp for the next Val. Calls reporter->InternalError() // Returns and updated kp for the next Val. Calls reporter->InternalError()
// upon errors, so there is no return value for invalid input. // upon errors, so there is no return value for invalid input.
bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, bool singleton) const;
bool singleton) const;
// Compute the size of the composite key. If v is non-nil then // Compute the size of the composite key. If v is non-nil then
// the value is computed for the particular list of values. // the value is computed for the particular list of values.
@ -60,8 +55,7 @@ protected:
// lower for the common case of these not being needed. // lower for the common case of these not being needed.
std::unique_ptr<std::unordered_map<const Func*, uint32_t>> func_to_func_id; std::unique_ptr<std::unordered_map<const Func*, uint32_t>> func_to_func_id;
std::unique_ptr<std::vector<FuncPtr>> func_id_to_func; std::unique_ptr<std::vector<FuncPtr>> func_id_to_func;
void BuildFuncMappings() void BuildFuncMappings() {
{
func_to_func_id = std::make_unique<std::unordered_map<const Func*, uint32_t>>(); func_to_func_id = std::make_unique<std::unordered_map<const Func*, uint32_t>>();
func_id_to_func = std::make_unique<std::vector<FuncPtr>>(); func_id_to_func = std::make_unique<std::vector<FuncPtr>>();
} }

View file

@ -22,18 +22,13 @@
#include "zeek/packet_analysis/protocol/tcp/TCP.h" #include "zeek/packet_analysis/protocol/tcp/TCP.h"
#include "zeek/session/Manager.h" #include "zeek/session/Manager.h"
namespace zeek namespace zeek {
{
uint64_t Connection::total_connections = 0; uint64_t Connection::total_connections = 0;
uint64_t Connection::current_connections = 0; uint64_t Connection::current_connections = 0;
Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt)
const Packet* pkt) : Session(t, connection_timeout, connection_status_update, detail::connection_status_update_interval), key(k) {
: Session(t, connection_timeout, connection_status_update,
detail::connection_status_update_interval),
key(k)
{
orig_addr = id->src_addr; orig_addr = id->src_addr;
resp_addr = id->dst_addr; resp_addr = id->dst_addr;
orig_port = id->src_port; orig_port = id->src_port;
@ -75,8 +70,7 @@ Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id,
encapsulation = pkt->encap; encapsulation = pkt->encap;
} }
Connection::~Connection() Connection::~Connection() {
{
if ( ! finished ) if ( ! finished )
reporter->InternalError("Done() not called before destruction of Connection"); reporter->InternalError("Done() not called before destruction of Connection");
@ -90,16 +84,11 @@ Connection::~Connection()
--current_connections; --current_connections;
} }
void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& arg_encap) void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& arg_encap) {
{ if ( encapsulation && arg_encap ) {
if ( encapsulation && arg_encap ) if ( *encapsulation != *arg_encap ) {
{ if ( tunnel_changed && (zeek::detail::tunnel_max_changes_per_connection == 0 ||
if ( *encapsulation != *arg_encap ) tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) ) {
{
if ( tunnel_changed &&
(zeek::detail::tunnel_max_changes_per_connection == 0 ||
tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) )
{
tunnel_changes++; tunnel_changes++;
EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
} }
@ -108,10 +97,8 @@ void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& a
} }
} }
else if ( encapsulation ) else if ( encapsulation ) {
{ if ( tunnel_changed ) {
if ( tunnel_changed )
{
EncapsulationStack empty; EncapsulationStack empty;
EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal()); EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal());
} }
@ -119,8 +106,7 @@ void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& a
encapsulation = nullptr; encapsulation = nullptr;
} }
else if ( arg_encap ) else if ( arg_encap ) {
{
if ( tunnel_changed ) if ( tunnel_changed )
EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal());
@ -128,14 +114,11 @@ void Connection::CheckEncapsulation(const std::shared_ptr<EncapsulationStack>& a
} }
} }
void Connection::Done() void Connection::Done() {
{
finished = 1; finished = 1;
if ( adapter ) if ( adapter ) {
{ if ( ConnTransport() == TRANSPORT_TCP ) {
if ( ConnTransport() == TRANSPORT_TCP )
{
auto* ta = static_cast<packet_analysis::TCP::TCPSessionAdapter*>(adapter); auto* ta = static_cast<packet_analysis::TCP::TCPSessionAdapter*>(adapter);
assert(ta->IsAnalyzer("TCP")); assert(ta->IsAnalyzer("TCP"));
analyzer::tcp::TCP_Endpoint* to = ta->Orig(); analyzer::tcp::TCP_Endpoint* to = ta->Orig();
@ -149,16 +132,14 @@ void Connection::Done()
} }
} }
void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data,
const u_char*& data, int& record_packet, int& record_content, int& record_packet, int& record_content,
// arguments for reproducing packets // arguments for reproducing packets
const Packet* pkt) const Packet* pkt) {
{
run_state::current_timestamp = t; run_state::current_timestamp = t;
run_state::current_pkt = pkt; run_state::current_pkt = pkt;
if ( adapter ) if ( adapter ) {
{
if ( adapter->Skipping() ) if ( adapter->Skipping() )
return; return;
@ -175,16 +156,10 @@ void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, i
run_state::current_pkt = nullptr; run_state::current_pkt = nullptr;
} }
bool Connection::IsReuse(double t, const u_char* pkt) bool Connection::IsReuse(double t, const u_char* pkt) { return adapter && adapter->IsReuse(t, pkt); }
{
return adapter && adapter->IsReuse(t, pkt);
}
bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base) {
uint32_t scaling_base) if ( ++counter == scaling_threshold ) {
{
if ( ++counter == scaling_threshold )
{
AddHistory(code); AddHistory(code);
auto new_threshold = scaling_threshold * scaling_base; auto new_threshold = scaling_threshold * scaling_base;
@ -204,8 +179,7 @@ bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scal
return false; return false;
} }
void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) {
{
if ( ! e ) if ( ! e )
return; return;
@ -217,13 +191,11 @@ void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t
EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold)); EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold));
} }
namespace namespace {
{
// Flip everything that needs to be flipped in the connection // Flip everything that needs to be flipped in the connection
// record that is known on this level. This needs to align // record that is known on this level. This needs to align
// with GetVal() and connection's layout in init-bare. // with GetVal() and connection's layout in init-bare.
void flip_conn_val(const RecordValPtr& conn_val) void flip_conn_val(const RecordValPtr& conn_val) {
{
// Flip the the conn_id (c$id). // Flip the the conn_id (c$id).
const auto& id_val = conn_val->GetField<zeek::RecordVal>(0); const auto& id_val = conn_val->GetField<zeek::RecordVal>(0);
const auto& tmp_addr = id_val->GetField<zeek::AddrVal>(0); const auto& tmp_addr = id_val->GetField<zeek::AddrVal>(0);
@ -238,12 +210,10 @@ void flip_conn_val(const RecordValPtr& conn_val)
conn_val->Assign(1, conn_val->GetField(2)); conn_val->Assign(1, conn_val->GetField(2));
conn_val->Assign(2, tmp_endp); conn_val->Assign(2, tmp_endp);
} }
} } // namespace
const RecordValPtr& Connection::GetVal() const RecordValPtr& Connection::GetVal() {
{ if ( ! conn_val ) {
if ( ! conn_val )
{
conn_val = make_intrusive<RecordVal>(id::connection); conn_val = make_intrusive<RecordVal>(id::connection);
TransportProto prot_type = ConnTransport(); TransportProto prot_type = ConnTransport();
@ -301,8 +271,7 @@ const RecordValPtr& Connection::GetVal()
conn_val->AssignTime(3, start_time); // ### conn_val->AssignTime(3, start_time); // ###
conn_val->AssignInterval(4, last_time - start_time); conn_val->AssignInterval(4, last_time - start_time);
if ( ! history.empty() ) if ( ! history.empty() ) {
{
auto v = conn_val->GetFieldAs<StringVal>(6); auto v = conn_val->GetFieldAs<StringVal>(6);
if ( *v != history ) if ( *v != history )
conn_val->Assign(6, history); conn_val->Assign(6, history);
@ -313,23 +282,15 @@ const RecordValPtr& Connection::GetVal()
return conn_val; return conn_val;
} }
analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) { return adapter ? adapter->FindChild(id) : nullptr; }
{
return adapter ? adapter->FindChild(id) : nullptr;
}
analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) {
{
return adapter ? adapter->FindChild(tag) : nullptr; return adapter ? adapter->FindChild(tag) : nullptr;
} }
analyzer::Analyzer* Connection::FindAnalyzer(const char* name) analyzer::Analyzer* Connection::FindAnalyzer(const char* name) { return adapter->FindChild(name); }
{
return adapter->FindChild(name);
}
void Connection::AppendAddl(const char* str) void Connection::AppendAddl(const char* str) {
{
const auto& cv = GetVal(); const auto& cv = GetVal();
const char* old = cv->GetFieldAs<StringVal>(6)->CheckString(); const char* old = cv->GetFieldAs<StringVal>(6)->CheckString();
@ -338,27 +299,23 @@ void Connection::AppendAddl(const char* str)
cv->Assign(6, util::fmt(format, old, str)); cv->Assign(6, util::fmt(format, old, str));
} }
void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol,
bool bol, bool eol, bool clear_state) bool clear_state) {
{
if ( primary_PIA ) if ( primary_PIA )
primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state); primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state);
} }
void Connection::RemovalEvent() void Connection::RemovalEvent() {
{
if ( connection_state_remove ) if ( connection_state_remove )
EnqueueEvent(connection_state_remove, nullptr, GetVal()); EnqueueEvent(connection_state_remove, nullptr, GetVal());
} }
void Connection::Weird(const char* name, const char* addl, const char* source) void Connection::Weird(const char* name, const char* addl, const char* source) {
{
weird = 1; weird = 1;
reporter->Weird(this, name, addl ? addl : "", source ? source : ""); reporter->Weird(this, name, addl ? addl : "", source ? source : "");
} }
void Connection::FlipRoles() void Connection::FlipRoles() {
{
IPAddr tmp_addr = resp_addr; IPAddr tmp_addr = resp_addr;
resp_addr = orig_addr; resp_addr = orig_addr;
orig_addr = tmp_addr; orig_addr = tmp_addr;
@ -395,23 +352,15 @@ void Connection::FlipRoles()
EnqueueEvent(connection_flipped, nullptr, GetVal()); EnqueueEvent(connection_flipped, nullptr, GetVal());
} }
void Connection::Describe(ODesc* d) const void Connection::Describe(ODesc* d) const {
{
session::Session::Describe(d); session::Session::Describe(d);
switch ( proto ) switch ( proto ) {
{ case TRANSPORT_TCP: d->Add("TCP"); break;
case TRANSPORT_TCP:
d->Add("TCP");
break;
case TRANSPORT_UDP: case TRANSPORT_UDP: d->Add("UDP"); break;
d->Add("UDP");
break;
case TRANSPORT_ICMP: case TRANSPORT_ICMP: d->Add("ICMP"); break;
d->Add("ICMP");
break;
case TRANSPORT_UNKNOWN: case TRANSPORT_UNKNOWN:
d->Add("unknown"); d->Add("unknown");
@ -419,8 +368,7 @@ void Connection::Describe(ODesc* d) const
break; break;
default: default: reporter->InternalError("unhandled transport type in Connection::Describe");
reporter->InternalError("unhandled transport type in Connection::Describe");
} }
d->SP(); d->SP();
@ -438,8 +386,7 @@ void Connection::Describe(ODesc* d) const
d->NL(); d->NL();
} }
void Connection::IDString(ODesc* d) const void Connection::IDString(ODesc* d) const {
{
d->Add(orig_addr); d->Add(orig_addr);
d->AddRaw(":", 1); d->AddRaw(":", 1);
d->Add(ntohs(orig_port)); d->Add(ntohs(orig_port));
@ -449,27 +396,21 @@ void Connection::IDString(ODesc* d) const
d->Add(ntohs(resp_port)); d->Add(ntohs(resp_port));
} }
void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) {
{
adapter = aa; adapter = aa;
primary_PIA = pia; primary_PIA = pia;
} }
void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) {
{
uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label; uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label;
if ( my_flow_label != flow_label ) if ( my_flow_label != flow_label ) {
{ if ( conn_val ) {
if ( conn_val )
{
RecordVal* endp = conn_val->GetFieldAs<RecordVal>(is_orig ? 1 : 2); RecordVal* endp = conn_val->GetFieldAs<RecordVal>(is_orig ? 1 : 2);
endp->Assign(4, flow_label); endp->Assign(4, flow_label);
} }
if ( connection_flow_label_changed && if ( connection_flow_label_changed && (is_orig ? saw_first_orig_packet : saw_first_resp_packet) ) {
(is_orig ? saw_first_orig_packet : saw_first_resp_packet) )
{
EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig), EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig),
val_mgr->Count(my_flow_label), val_mgr->Count(flow_label)); val_mgr->Count(my_flow_label), val_mgr->Count(flow_label));
} }
@ -483,8 +424,7 @@ void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label)
saw_first_resp_packet = 1; saw_first_resp_packet = 1;
} }
bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) {
{
return detail::PermitWeird(weird_state, name, threshold, rate, duration); return detail::PermitWeird(weird_state, name, threshold, rate, duration);
} }

View file

@ -19,8 +19,7 @@
#include "zeek/iosource/Packet.h" #include "zeek/iosource/Packet.h"
#include "zeek/session/Session.h" #include "zeek/session/Session.h"
namespace zeek namespace zeek {
{
class Connection; class Connection;
class EncapsulationStack; class EncapsulationStack;
@ -30,12 +29,10 @@ class RecordVal;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using RecordValPtr = IntrusivePtr<RecordVal>; using RecordValPtr = IntrusivePtr<RecordVal>;
namespace session namespace session {
{
class Manager; class Manager;
} }
namespace detail namespace detail {
{
class Specific_RE_Matcher; class Specific_RE_Matcher;
class RuleEndpointState; class RuleEndpointState;
@ -43,25 +40,21 @@ class RuleHdrTest;
} // namespace detail } // namespace detail
namespace analyzer namespace analyzer {
{
class Analyzer; class Analyzer;
} }
namespace packet_analysis::IP namespace packet_analysis::IP {
{
class SessionAdapter; class SessionAdapter;
} }
enum ConnEventToFlag enum ConnEventToFlag {
{
NUL_IN_LINE, NUL_IN_LINE,
SINGULAR_CR, SINGULAR_CR,
SINGULAR_LF, SINGULAR_LF,
NUM_EVENTS_TO_FLAG, NUM_EVENTS_TO_FLAG,
}; };
struct ConnTuple struct ConnTuple {
{
IPAddr src_addr; IPAddr src_addr;
IPAddr dst_addr; IPAddr dst_addr;
uint32_t src_port = 0; uint32_t src_port = 0;
@ -70,17 +63,13 @@ struct ConnTuple
TransportProto proto = TRANSPORT_UNKNOWN; TransportProto proto = TRANSPORT_UNKNOWN;
}; };
static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, uint32_t p2) {
uint32_t p2)
{
return addr1 < addr2 || (addr1 == addr2 && p1 < p2); return addr1 < addr2 || (addr1 == addr2 && p1 < p2);
} }
class Connection final : public session::Session class Connection final : public session::Session {
{
public: public:
Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt);
const Packet* pkt);
~Connection() override; ~Connection() override;
/** /**
@ -109,8 +98,8 @@ public:
// If record_content is true, then its entire contents should // If record_content is true, then its entire contents should
// be recorded, otherwise just up through the transport header. // be recorded, otherwise just up through the transport header.
// Both are assumed set to true when called. // Both are assumed set to true when called.
void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data,
const u_char*& data, int& record_packet, int& record_content, int& record_packet, int& record_content,
// arguments for reproducing packets // arguments for reproducing packets
const Packet* pkt); const Packet* pkt);
@ -118,10 +107,8 @@ public:
// connection is in the session map. If it is removed, the key // connection is in the session map. If it is removed, the key
// should be marked invalid. // should be marked invalid.
const detail::ConnKey& Key() const { return key; } const detail::ConnKey& Key() const { return key; }
session::detail::Key SessionKey(bool copy) const override session::detail::Key SessionKey(bool copy) const override {
{ return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE, copy};
return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE,
copy};
} }
const IPAddr& OrigAddr() const { return orig_addr; } const IPAddr& OrigAddr() const { return orig_addr; }
@ -137,8 +124,7 @@ public:
analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree. analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree.
TransportProto ConnTransport() const { return proto; } TransportProto ConnTransport() const { return proto; }
std::string TransportIdentifier() const override std::string TransportIdentifier() const override {
{
if ( proto == TRANSPORT_TCP ) if ( proto == TRANSPORT_TCP )
return "tcp"; return "tcp";
else if ( proto == TRANSPORT_UDP ) else if ( proto == TRANSPORT_UDP )
@ -164,8 +150,8 @@ public:
*/ */
void AppendAddl(const char* str); void AppendAddl(const char* str);
void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol,
bool eol, bool clear_state); bool clear_state);
/** /**
* Generates connection removal event(s). * Generates connection removal event(s).
@ -175,10 +161,8 @@ public:
void Weird(const char* name, const char* addl = "", const char* source = ""); void Weird(const char* name, const char* addl = "", const char* source = "");
bool DidWeird() const { return weird != 0; } bool DidWeird() const { return weird != 0; }
inline bool FlagEvent(ConnEventToFlag e) inline bool FlagEvent(ConnEventToFlag e) {
{ if ( e >= 0 && e < NUM_EVENTS_TO_FLAG ) {
if ( e >= 0 && e < NUM_EVENTS_TO_FLAG )
{
if ( suppress_event & (1 << e) ) if ( suppress_event & (1 << e) )
return false; return false;
suppress_event |= 1 << e; suppress_event |= 1 << e;
@ -196,10 +180,8 @@ public:
static uint64_t CurrentConnections() { return current_connections; } static uint64_t CurrentConnections() { return current_connections; }
// Returns true if the history was already seen, false otherwise. // Returns true if the history was already seen, false otherwise.
bool CheckHistory(uint32_t mask, char code) bool CheckHistory(uint32_t mask, char code) {
{ if ( (hist_seen & mask) == 0 ) {
if ( (hist_seen & mask) == 0 )
{
hist_seen |= mask; hist_seen |= mask;
AddHistory(code); AddHistory(code);
return false; return false;
@ -212,8 +194,7 @@ public:
// code if it has crossed the next scaling threshold. Scaling // code if it has crossed the next scaling threshold. Scaling
// is done in terms of powers of the third argument. // is done in terms of powers of the third argument.
// Returns true if the threshold was crossed, false otherwise. // Returns true if the threshold was crossed, false otherwise.
bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base = 10);
uint32_t scaling_base = 10);
void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold); void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold);

View file

@ -8,12 +8,10 @@
#include "zeek/EquivClass.h" #include "zeek/EquivClass.h"
#include "zeek/Hash.h" #include "zeek/Hash.h"
namespace zeek::detail namespace zeek::detail {
{
DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* arg_nfa_states, DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* arg_nfa_states,
AcceptingSet* arg_accept) AcceptingSet* arg_accept) {
{
state_num = arg_state_num; state_num = arg_state_num;
num_sym = ec->NumClasses(); num_sym = ec->NumClasses();
nfa_states = arg_nfa_states; nfa_states = arg_nfa_states;
@ -28,39 +26,31 @@ DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* ar
xtions[i] = DFA_UNCOMPUTED_STATE_PTR; xtions[i] = DFA_UNCOMPUTED_STATE_PTR;
} }
DFA_State::~DFA_State() DFA_State::~DFA_State() {
{
delete[] xtions; delete[] xtions;
delete nfa_states; delete nfa_states;
delete accept; delete accept;
delete meta_ec; delete meta_ec;
} }
void DFA_State::AddXtion(int sym, DFA_State* next_state) void DFA_State::AddXtion(int sym, DFA_State* next_state) { xtions[sym] = next_state; }
{
xtions[sym] = next_state;
}
void DFA_State::SymPartition(const EquivClass* ec) void DFA_State::SymPartition(const EquivClass* ec) {
{
// Partitioning is done by creating equivalence classes for those // Partitioning is done by creating equivalence classes for those
// characters which have out-transitions from the given state. Thus // characters which have out-transitions from the given state. Thus
// we are really creating equivalence classes of equivalence classes. // we are really creating equivalence classes of equivalence classes.
meta_ec = new EquivClass(ec->NumClasses()); meta_ec = new EquivClass(ec->NumClasses());
assert(nfa_states); assert(nfa_states);
for ( int i = 0; i < nfa_states->length(); ++i ) for ( int i = 0; i < nfa_states->length(); ++i ) {
{
NFA_State* n = (*nfa_states)[i]; NFA_State* n = (*nfa_states)[i];
int sym = n->TransSym(); int sym = n->TransSym();
if ( sym == SYM_EPSILON ) if ( sym == SYM_EPSILON )
continue; continue;
if ( sym != SYM_CCL ) if ( sym != SYM_CCL ) { // character transition
{ // character transition if ( ec->IsRep(sym) ) {
if ( ec->IsRep(sym) )
{
sym = ec->SymEquivClass(sym); sym = ec->SymEquivClass(sym);
meta_ec->UniqueChar(sym); meta_ec->UniqueChar(sym);
} }
@ -74,11 +64,9 @@ void DFA_State::SymPartition(const EquivClass* ec)
meta_ec->BuildECs(); meta_ec->BuildECs();
} }
DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) {
{
int equiv_sym = meta_ec->EquivRep(sym); int equiv_sym = meta_ec->EquivRep(sym);
if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) {
{
AddXtion(sym, xtions[equiv_sym]); AddXtion(sym, xtions[equiv_sym]);
return xtions[sym]; return xtions[sym];
} }
@ -88,14 +76,12 @@ DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine)
DFA_State* next_d; DFA_State* next_d;
NFA_state_list* ns = SymFollowSet(equiv_sym, ec); NFA_state_list* ns = SymFollowSet(equiv_sym, ec);
if ( ns->length() > 0 ) if ( ns->length() > 0 ) {
{
NFA_state_list* state_set = epsilon_closure(ns); NFA_state_list* state_set = epsilon_closure(ns);
if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) ) if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) )
delete state_set; delete state_set;
} }
else else {
{
delete ns; delete ns;
next_d = nullptr; // Jam next_d = nullptr; // Jam
} }
@ -107,8 +93,7 @@ DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine)
return xtions[sym]; return xtions[sym];
} }
void DFA_State::AppendIfNew(int sym, int_list* sym_list) void DFA_State::AppendIfNew(int sym, int_list* sym_list) {
{
for ( auto value : *sym_list ) for ( auto value : *sym_list )
if ( value == sym ) if ( value == sym )
return; return;
@ -116,26 +101,21 @@ void DFA_State::AppendIfNew(int sym, int_list* sym_list)
sym_list->push_back(sym); sym_list->push_back(sym);
} }
NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) {
{
NFA_state_list* ns = new NFA_state_list; NFA_state_list* ns = new NFA_state_list;
assert(nfa_states); assert(nfa_states);
for ( int i = 0; i < nfa_states->length(); ++i ) for ( int i = 0; i < nfa_states->length(); ++i ) {
{
NFA_State* n = (*nfa_states)[i]; NFA_State* n = (*nfa_states)[i];
if ( n->TransSym() == SYM_CCL ) if ( n->TransSym() == SYM_CCL ) { // it's a character class
{ // it's a character class
CCL* ccl = n->TransCCL(); CCL* ccl = n->TransCCL();
int_list* syms = ccl->Syms(); int_list* syms = ccl->Syms();
if ( ccl->IsNegated() ) if ( ccl->IsNegated() ) {
{
size_t j; size_t j;
for ( j = 0; j < syms->size(); ++j ) for ( j = 0; j < syms->size(); ++j ) {
{
// Loop through (sorted) negated // Loop through (sorted) negated
// character class, which has // character class, which has
// presumably already been converted // presumably already been converted
@ -151,25 +131,21 @@ NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec)
continue; continue;
} }
for ( auto sym : *syms ) for ( auto sym : *syms ) {
{
if ( sym > ec_sym ) if ( sym > ec_sym )
break; break;
if ( sym == ec_sym ) if ( sym == ec_sym ) {
{
n->AddXtionsTo(ns); n->AddXtionsTo(ns);
break; break;
} }
} }
} }
else if ( n->TransSym() == SYM_EPSILON ) else if ( n->TransSym() == SYM_EPSILON ) { // do nothing
{ // do nothing
} }
else if ( ec->IsRep(n->TransSym()) ) else if ( ec->IsRep(n->TransSym()) ) {
{
if ( ec_sym == ec->SymEquivClass(n->TransSym()) ) if ( ec_sym == ec->SymEquivClass(n->TransSym()) )
n->AddXtionsTo(ns); n->AddXtionsTo(ns);
} }
@ -179,14 +155,11 @@ NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec)
return ns; return ns;
} }
void DFA_State::ClearMarks() void DFA_State::ClearMarks() {
{ if ( mark ) {
if ( mark )
{
SetMark(nullptr); SetMark(nullptr);
for ( int i = 0; i < num_sym; ++i ) for ( int i = 0; i < num_sym; ++i ) {
{
DFA_State* s = xtions[i]; DFA_State* s = xtions[i];
if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) if ( s && s != DFA_UNCOMPUTED_STATE_PTR )
@ -195,20 +168,15 @@ void DFA_State::ClearMarks()
} }
} }
void DFA_State::Describe(ODesc* d) const void DFA_State::Describe(ODesc* d) const { d->Add("DFA state"); }
{
d->Add("DFA state");
}
void DFA_State::Dump(FILE* f, DFA_Machine* m) void DFA_State::Dump(FILE* f, DFA_Machine* m) {
{
if ( mark ) if ( mark )
return; return;
fprintf(f, "\nDFA state %d:", StateNum()); fprintf(f, "\nDFA state %d:", StateNum());
if ( accept ) if ( accept ) {
{
AcceptingSet::const_iterator it; AcceptingSet::const_iterator it;
for ( it = accept->begin(); it != accept->end(); ++it ) for ( it = accept->begin(); it != accept->end(); ++it )
@ -218,8 +186,7 @@ void DFA_State::Dump(FILE* f, DFA_Machine* m)
fprintf(f, "\n"); fprintf(f, "\n");
int num_trans = 0; int num_trans = 0;
for ( int sym = 0; sym < num_sym; ++sym ) for ( int sym = 0; sym < num_sym; ++sym ) {
{
DFA_State* s = xtions[sym]; DFA_State* s = xtions[sym];
if ( ! s ) if ( ! s )
@ -244,11 +211,9 @@ void DFA_State::Dump(FILE* f, DFA_Machine* m)
snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1)); snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1));
if ( s == DFA_UNCOMPUTED_STATE_PTR ) if ( s == DFA_UNCOMPUTED_STATE_PTR )
fprintf(f, "%stransition on %s to <uncomputed>", ++num_trans == 1 ? "\t" : "\n\t", fprintf(f, "%stransition on %s to <uncomputed>", ++num_trans == 1 ? "\t" : "\n\t", xbuf);
xbuf);
else else
fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, s->StateNum());
s->StateNum());
delete[] xbuf; delete[] xbuf;
@ -260,8 +225,7 @@ void DFA_State::Dump(FILE* f, DFA_Machine* m)
SetMark(this); SetMark(this);
for ( int sym = 0; sym < num_sym; ++sym ) for ( int sym = 0; sym < num_sym; ++sym ) {
{
DFA_State* s = xtions[sym]; DFA_State* s = xtions[sym];
if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) if ( s && s != DFA_UNCOMPUTED_STATE_PTR )
@ -269,10 +233,8 @@ void DFA_State::Dump(FILE* f, DFA_Machine* m)
} }
} }
void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed) void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed) {
{ for ( int sym = 0; sym < num_sym; ++sym ) {
for ( int sym = 0; sym < num_sym; ++sym )
{
DFA_State* s = xtions[sym]; DFA_State* s = xtions[sym];
if ( s == DFA_UNCOMPUTED_STATE_PTR ) if ( s == DFA_UNCOMPUTED_STATE_PTR )
@ -282,23 +244,17 @@ void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed)
} }
} }
unsigned int DFA_State::Size() unsigned int DFA_State::Size() {
{
return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) + return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) +
(accept ? util::pad_size(sizeof(int) * accept->size()) : 0) + (accept ? util::pad_size(sizeof(int) * accept->size()) : 0) +
(nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) + (nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) +
(meta_ec ? meta_ec->Size() : 0); (meta_ec ? meta_ec->Size() : 0);
} }
DFA_State_Cache::DFA_State_Cache() DFA_State_Cache::DFA_State_Cache() { hits = misses = 0; }
{
hits = misses = 0;
}
DFA_State_Cache::~DFA_State_Cache() DFA_State_Cache::~DFA_State_Cache() {
{ for ( auto& entry : states ) {
for ( auto& entry : states )
{
assert(entry.second); assert(entry.second);
Unref(entry.second); Unref(entry.second);
} }
@ -306,22 +262,18 @@ DFA_State_Cache::~DFA_State_Cache()
states.clear(); states.clear();
} }
DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest) DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest) {
{
// We assume that state ID's don't exceed 10 digits, plus // We assume that state ID's don't exceed 10 digits, plus
// we allow one more character for the delimiter. // we allow one more character for the delimiter.
auto id_tag_buf = std::make_unique<u_char[]>(nfas.length() * 11 + 1); auto id_tag_buf = std::make_unique<u_char[]>(nfas.length() * 11 + 1);
auto id_tag = id_tag_buf.get(); auto id_tag = id_tag_buf.get();
u_char* p = id_tag; u_char* p = id_tag;
for ( int i = 0; i < nfas.length(); ++i ) for ( int i = 0; i < nfas.length(); ++i ) {
{
NFA_State* n = nfas[i]; NFA_State* n = nfas[i];
if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) {
{
int id = n->ID(); int id = n->ID();
do do {
{
*p++ = '0' + (char)(id % 10); *p++ = '0' + (char)(id % 10);
id /= 10; id /= 10;
} while ( id > 0 ); } while ( id > 0 );
@ -338,8 +290,7 @@ DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest
*digest = DigestStr(reinterpret_cast<const unsigned char*>(hash), 16); *digest = DigestStr(reinterpret_cast<const unsigned char*>(hash), 16);
auto entry = states.find(*digest); auto entry = states.find(*digest);
if ( entry == states.end() ) if ( entry == states.end() ) {
{
++misses; ++misses;
return nullptr; return nullptr;
} }
@ -350,14 +301,12 @@ DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest
return entry->second; return entry->second;
} }
DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) {
{
states.emplace(std::move(digest), state); states.emplace(std::move(digest), state);
return state; return state;
} }
void DFA_State_Cache::GetStats(Stats* s) void DFA_State_Cache::GetStats(Stats* s) {
{
s->dfa_states = 0; s->dfa_states = 0;
s->nfa_states = 0; s->nfa_states = 0;
s->computed = 0; s->computed = 0;
@ -366,8 +315,7 @@ void DFA_State_Cache::GetStats(Stats* s)
s->hits = hits; s->hits = hits;
s->misses = misses; s->misses = misses;
for ( const auto& state : states ) for ( const auto& state : states ) {
{
DFA_State* e = state.second; DFA_State* e = state.second;
++s->dfa_states; ++s->dfa_states;
s->nfa_states += e->NFAStateNum(); s->nfa_states += e->NFAStateNum();
@ -376,8 +324,7 @@ void DFA_State_Cache::GetStats(Stats* s)
} }
} }
DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) {
{
state_count = 0; state_count = 0;
nfa = n; nfa = n;
@ -390,38 +337,29 @@ DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec)
NFA_state_list* ns = new NFA_state_list; NFA_state_list* ns = new NFA_state_list;
ns->push_back(n->FirstState()); ns->push_back(n->FirstState());
if ( ns->length() > 0 ) if ( ns->length() > 0 ) {
{
NFA_state_list* state_set = epsilon_closure(ns); NFA_state_list* state_set = epsilon_closure(ns);
StateSetToDFA_State(state_set, start_state, ec); StateSetToDFA_State(state_set, start_state, ec);
} }
else else {
{
start_state = nullptr; // Jam start_state = nullptr; // Jam
delete ns; delete ns;
} }
} }
DFA_Machine::~DFA_Machine() DFA_Machine::~DFA_Machine() {
{
delete dfa_state_cache; delete dfa_state_cache;
Unref(nfa); Unref(nfa);
} }
void DFA_Machine::Describe(ODesc* d) const void DFA_Machine::Describe(ODesc* d) const { d->Add("DFA machine"); }
{
d->Add("DFA machine");
}
void DFA_Machine::Dump(FILE* f) void DFA_Machine::Dump(FILE* f) {
{
start_state->Dump(f, this); start_state->Dump(f, this);
start_state->ClearMarks(); start_state->ClearMarks();
} }
bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec) {
const EquivClass* ec)
{
DigestStr digest; DigestStr digest;
d = dfa_state_cache->Lookup(*state_set, &digest); d = dfa_state_cache->Lookup(*state_set, &digest);
@ -430,16 +368,14 @@ bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d,
AcceptingSet* accept = new AcceptingSet; AcceptingSet* accept = new AcceptingSet;
for ( int i = 0; i < state_set->length(); ++i ) for ( int i = 0; i < state_set->length(); ++i ) {
{
int acc = (*state_set)[i]->Accept(); int acc = (*state_set)[i]->Accept();
if ( acc != NO_ACCEPT ) if ( acc != NO_ACCEPT )
accept->insert(acc); accept->insert(acc);
} }
if ( accept->empty() ) if ( accept->empty() ) {
{
delete accept; delete accept;
accept = nullptr; accept = nullptr;
} }
@ -450,8 +386,7 @@ bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d,
return true; return true;
} }
int DFA_Machine::Rep(int sym) int DFA_Machine::Rep(int sym) {
{
for ( int i = 0; i < NUM_SYM; ++i ) for ( int i = 0; i < NUM_SYM; ++i )
if ( ec->SymEquivClass(i) == sym ) if ( ec->SymEquivClass(i) == sym )
return i; return i;

View file

@ -11,8 +11,7 @@
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/RE.h" // for typedef AcceptingSet #include "zeek/RE.h" // for typedef AcceptingSet
namespace zeek::detail namespace zeek::detail {
{
class DFA_State; class DFA_State;
class DFA_Machine; class DFA_Machine;
@ -22,11 +21,9 @@ class DFA_Machine;
#define DFA_UNCOMPUTED_STATE -2 #define DFA_UNCOMPUTED_STATE -2
#define DFA_UNCOMPUTED_STATE_PTR ((DFA_State*)DFA_UNCOMPUTED_STATE) #define DFA_UNCOMPUTED_STATE_PTR ((DFA_State*)DFA_UNCOMPUTED_STATE)
class DFA_State : public Obj class DFA_State : public Obj {
{
public: public:
DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, AcceptingSet* accept);
AcceptingSet* accept);
~DFA_State() override; ~DFA_State() override;
int StateNum() const { return state_num; } int StateNum() const { return state_num; }
@ -72,8 +69,7 @@ protected:
using DigestStr = std::basic_string<u_char>; using DigestStr = std::basic_string<u_char>;
class DFA_State_Cache class DFA_State_Cache {
{
public: public:
DFA_State_Cache(); DFA_State_Cache();
~DFA_State_Cache(); ~DFA_State_Cache();
@ -86,8 +82,7 @@ public:
int NumEntries() const { return states.size(); } int NumEntries() const { return states.size(); }
struct Stats struct Stats {
{
// Sum of all NFA states // Sum of all NFA states
unsigned int nfa_states; unsigned int nfa_states;
unsigned int dfa_states; unsigned int dfa_states;
@ -108,8 +103,7 @@ private:
std::map<DigestStr, DFA_State*> states; std::map<DigestStr, DFA_State*> states;
}; };
class DFA_Machine : public Obj class DFA_Machine : public Obj {
{
public: public:
DFA_Machine(NFA_Machine* n, EquivClass* ec); DFA_Machine(NFA_Machine* n, EquivClass* ec);
~DFA_Machine() override; ~DFA_Machine() override;
@ -142,8 +136,7 @@ protected:
NFA_Machine* nfa; NFA_Machine* nfa;
}; };
inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) {
{
if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR ) if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR )
return ComputeXtion(sym, machine); return ComputeXtion(sym, machine);
else else

View file

@ -6,11 +6,9 @@
#include "zeek/DNS_Mgr.h" #include "zeek/DNS_Mgr.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) {
{
Init(h); Init(h);
req_host = host; req_host = host;
req_ttl = ttl; req_ttl = ttl;
@ -20,16 +18,14 @@ DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int
names.push_back(std::move(host)); names.push_back(std::move(host));
} }
DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) {
{
Init(h); Init(h);
req_addr = addr; req_addr = addr;
req_ttl = ttl; req_ttl = ttl;
req_type = T_PTR; req_type = T_PTR;
} }
DNS_Mapping::DNS_Mapping(FILE* f) DNS_Mapping::DNS_Mapping(FILE* f) {
{
Clear(); Clear();
init_failed = true; init_failed = true;
@ -38,8 +34,7 @@ DNS_Mapping::DNS_Mapping(FILE* f)
char buf[512]; char buf[512];
if ( ! fgets(buf, sizeof(buf), f) ) if ( ! fgets(buf, sizeof(buf), f) ) {
{
no_mapping = true; no_mapping = true;
return; return;
} }
@ -49,9 +44,8 @@ DNS_Mapping::DNS_Mapping(FILE* f)
int failed_local; int failed_local;
int num_addrs; int num_addrs;
if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, &failed_local,
&failed_local, name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) {
{
no_mapping = true; no_mapping = true;
return; return;
} }
@ -65,8 +59,7 @@ DNS_Mapping::DNS_Mapping(FILE* f)
names.emplace_back(name_buf); names.emplace_back(name_buf);
for ( int i = 0; i < num_addrs; ++i ) for ( int i = 0; i < num_addrs; ++i ) {
{
if ( ! fgets(buf, sizeof(buf), f) ) if ( ! fgets(buf, sizeof(buf), f) )
return; return;
@ -80,13 +73,11 @@ DNS_Mapping::DNS_Mapping(FILE* f)
init_failed = false; init_failed = false;
} }
ListValPtr DNS_Mapping::Addrs() ListValPtr DNS_Mapping::Addrs() {
{
if ( failed ) if ( failed )
return nullptr; return nullptr;
if ( ! addrs_val ) if ( ! addrs_val ) {
{
addrs_val = make_intrusive<ListVal>(TYPE_ADDR); addrs_val = make_intrusive<ListVal>(TYPE_ADDR);
for ( const auto& addr : addrs ) for ( const auto& addr : addrs )
@ -96,8 +87,7 @@ ListValPtr DNS_Mapping::Addrs()
return addrs_val; return addrs_val;
} }
TableValPtr DNS_Mapping::AddrsSet() TableValPtr DNS_Mapping::AddrsSet() {
{
auto l = Addrs(); auto l = Addrs();
if ( ! l || l->Length() == 0 ) if ( ! l || l->Length() == 0 )
@ -106,8 +96,7 @@ TableValPtr DNS_Mapping::AddrsSet()
return l->ToSetVal(); return l->ToSetVal();
} }
StringValPtr DNS_Mapping::Host() StringValPtr DNS_Mapping::Host() {
{
if ( failed || names.empty() ) if ( failed || names.empty() )
return nullptr; return nullptr;
@ -117,16 +106,14 @@ StringValPtr DNS_Mapping::Host()
return host_val; return host_val;
} }
void DNS_Mapping::Init(struct hostent* h) void DNS_Mapping::Init(struct hostent* h) {
{
no_mapping = false; no_mapping = false;
init_failed = false; init_failed = false;
creation_time = util::current_time(); creation_time = util::current_time();
host_val = nullptr; host_val = nullptr;
addrs_val = nullptr; addrs_val = nullptr;
if ( ! h ) if ( ! h ) {
{
Clear(); Clear();
return; return;
} }
@ -136,10 +123,8 @@ void DNS_Mapping::Init(struct hostent* h)
// TODO: this could easily be expanded to include all of the aliases as well // TODO: this could easily be expanded to include all of the aliases as well
names.emplace_back(h->h_name); names.emplace_back(h->h_name);
if ( h->h_addr_list ) if ( h->h_addr_list ) {
{ for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) {
for ( int i = 0; h->h_addr_list[i] != NULL; ++i )
{
if ( h->h_addrtype == AF_INET ) if ( h->h_addrtype == AF_INET )
addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network); addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network);
else if ( h->h_addrtype == AF_INET6 ) else if ( h->h_addrtype == AF_INET6 )
@ -150,8 +135,7 @@ void DNS_Mapping::Init(struct hostent* h)
failed = false; failed = false;
} }
void DNS_Mapping::Clear() void DNS_Mapping::Clear() {
{
names.clear(); names.clear();
host_val = nullptr; host_val = nullptr;
addrs.clear(); addrs.clear();
@ -161,8 +145,7 @@ void DNS_Mapping::Clear()
failed = true; failed = true;
} }
void DNS_Mapping::Save(FILE* f) const void DNS_Mapping::Save(FILE* f) const {
{
fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(),
req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed,
names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl); names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl);
@ -171,8 +154,7 @@ void DNS_Mapping::Save(FILE* f) const
fprintf(f, "%s\n", addr.AsString().c_str()); fprintf(f, "%s\n", addr.AsString().c_str());
} }
void DNS_Mapping::Merge(const DNS_MappingPtr& other) void DNS_Mapping::Merge(const DNS_MappingPtr& other) {
{
std::copy(other->names.begin(), other->names.end(), std::back_inserter(names)); std::copy(other->names.begin(), other->names.end(), std::back_inserter(names));
std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs)); std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs));
} }
@ -181,20 +163,15 @@ void DNS_Mapping::Merge(const DNS_MappingPtr& other)
// allows us to change the structure of the cache without breaking something in DNS_Mgr. // allows us to change the structure of the cache without breaking something in DNS_Mgr.
constexpr int FILE_VERSION = 1; constexpr int FILE_VERSION = 1;
void DNS_Mapping::InitializeCache(FILE* f) void DNS_Mapping::InitializeCache(FILE* f) { fprintf(f, "%d\n", FILE_VERSION); }
{
fprintf(f, "%d\n", FILE_VERSION);
}
bool DNS_Mapping::ValidateCacheVersion(FILE* f) bool DNS_Mapping::ValidateCacheVersion(FILE* f) {
{
char buf[512]; char buf[512];
if ( ! fgets(buf, sizeof(buf), f) ) if ( ! fgets(buf, sizeof(buf), f) )
return false; return false;
int version; int version;
if ( sscanf(buf, "%d", &version) != 1 ) if ( sscanf(buf, "%d", &version) != 1 ) {
{
reporter->Warning("Existing DNS cache did not have correct version, ignoring"); reporter->Warning("Existing DNS cache did not have correct version, ignoring");
return false; return false;
} }
@ -206,8 +183,7 @@ bool DNS_Mapping::ValidateCacheVersion(FILE* f)
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
////////////////////////////////////////////////////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////////////////
TEST_CASE("dns_mapping init null hostent") TEST_CASE("dns_mapping init null hostent") {
{
DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A); DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A);
CHECK(! mapping.Valid()); CHECK(! mapping.Valid());
@ -216,8 +192,7 @@ TEST_CASE("dns_mapping init null hostent")
CHECK(mapping.Host() == nullptr); CHECK(mapping.Host() == nullptr);
} }
TEST_CASE("dns_mapping init host") TEST_CASE("dns_mapping init host") {
{
IPAddr addr("1.2.3.4"); IPAddr addr("1.2.3.4");
in4_addr in4; in4_addr in4;
addr.CopyIPv4(&in4); addr.CopyIPv4(&in4);
@ -255,8 +230,7 @@ TEST_CASE("dns_mapping init host")
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping init addr") TEST_CASE("dns_mapping init addr") {
{
IPAddr addr("1.2.3.4"); IPAddr addr("1.2.3.4");
in4_addr in4; in4_addr in4;
addr.CopyIPv4(&in4); addr.CopyIPv4(&in4);
@ -294,8 +268,7 @@ TEST_CASE("dns_mapping init addr")
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping save reload") TEST_CASE("dns_mapping save reload") {
{
// TODO: this test uses fmemopen and mkdtemp, both of which aren't available on // TODO: this test uses fmemopen and mkdtemp, both of which aren't available on
// Windows. We'll have to figure out another way to do this test there. // Windows. We'll have to figure out another way to do this test there.
#ifndef _MSC_VER #ifndef _MSC_VER
@ -362,8 +335,7 @@ TEST_CASE("dns_mapping save reload")
#endif #endif
} }
TEST_CASE("dns_mapping multiple addresses") TEST_CASE("dns_mapping multiple addresses") {
{
IPAddr addr("1.2.3.4"); IPAddr addr("1.2.3.4");
in4_addr in4_1; in4_addr in4_1;
addr.CopyIPv4(&in4_1); addr.CopyIPv4(&in4_1);
@ -399,8 +371,7 @@ TEST_CASE("dns_mapping multiple addresses")
delete[] he.h_name; delete[] he.h_name;
} }
TEST_CASE("dns_mapping ipv6") TEST_CASE("dns_mapping ipv6") {
{
IPAddr addr("64:ff9b:1::"); IPAddr addr("64:ff9b:1::");
in6_addr in6; in6_addr in6;
addr.CopyIPv6(&in6); addr.CopyIPv6(&in6);

View file

@ -8,14 +8,12 @@
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
class DNS_Mapping; class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>; using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Mapping class DNS_Mapping {
{
public: public:
DNS_Mapping() = delete; DNS_Mapping() = delete;
DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type); DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type);

File diff suppressed because it is too large Load diff

View file

@ -27,14 +27,14 @@ typedef struct ares_channeldata* ares_channel;
#define T_TXT 16 #define T_TXT 16
#endif #endif
namespace zeek namespace zeek {
{
class Val; class Val;
class ListVal; class ListVal;
class TableVal; class TableVal;
class StringVal; class StringVal;
template <class T> class IntrusivePtr; template<class T>
class IntrusivePtr;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using ListValPtr = IntrusivePtr<ListVal>; using ListValPtr = IntrusivePtr<ListVal>;
using TableValPtr = IntrusivePtr<TableVal>; using TableValPtr = IntrusivePtr<TableVal>;
@ -42,28 +42,24 @@ using StringValPtr = IntrusivePtr<StringVal>;
} // namespace zeek } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class DNS_Mapping; class DNS_Mapping;
using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>; using DNS_MappingPtr = std::shared_ptr<DNS_Mapping>;
class DNS_Request; class DNS_Request;
enum DNS_MgrMode enum DNS_MgrMode {
{
DNS_PRIME, // used to prime the cache DNS_PRIME, // used to prime the cache
DNS_FORCE, // internal error if cache miss DNS_FORCE, // internal error if cache miss
DNS_DEFAULT, // lookup names as they're requested DNS_DEFAULT, // lookup names as they're requested
DNS_FAKE, // don't look up names, just return dummy results DNS_FAKE, // don't look up names, just return dummy results
}; };
class DNS_Mgr : public iosource::IOSource class DNS_Mgr : public iosource::IOSource {
{
public: public:
/** /**
* Base class for callback handling for asynchronous lookups. * Base class for callback handling for asynchronous lookups.
*/ */
class LookupCallback class LookupCallback {
{
public: public:
virtual ~LookupCallback() = default; virtual ~LookupCallback() = default;
@ -200,8 +196,7 @@ public:
*/ */
bool Save(); bool Save();
struct Stats struct Stats {
{
unsigned long requests; // These count only async requests. unsigned long requests; // These count only async requests.
unsigned long successful; unsigned long successful;
unsigned long failed; unsigned long failed;
@ -254,12 +249,9 @@ protected:
friend class LookupCallback; friend class LookupCallback;
friend class DNS_Request; friend class DNS_Request;
StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, bool check_failed = false);
bool check_failed = false); TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, bool check_failed = false);
TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, StringValPtr LookupOtherInCache(const std::string& name, int request_type, bool cleanup_expired = false);
bool check_failed = false);
StringValPtr LookupOtherInCache(const std::string& name, int request_type,
bool cleanup_expired = false);
// Finish the request if we have a result. If not, time it out if // Finish the request if we have a result. If not, time it out if
// requested. // requested.
@ -307,8 +299,7 @@ protected:
using CallbackList = std::list<LookupCallback*>; using CallbackList = std::list<LookupCallback*>;
struct AsyncRequest struct AsyncRequest {
{
double time = 0.0; double time = 0.0;
IPAddr addr; IPAddr addr;
std::string host; std::string host;
@ -316,9 +307,7 @@ protected:
int type = 0; int type = 0;
bool processed = false; bool processed = false;
AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) {}
{
}
AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) {} AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) {}
void Resolved(const std::string& name); void Resolved(const std::string& name);
@ -326,8 +315,7 @@ protected:
void Timeout(); void Timeout();
}; };
struct AsyncRequestCompare struct AsyncRequestCompare {
{
bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; } bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; }
}; };

View file

@ -18,17 +18,12 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/module_util.h" #include "zeek/module_util.h"
namespace zeek::detail namespace zeek::detail {
{
// BreakpointTimer used for time-based breakpoints // BreakpointTimer used for time-based breakpoints
class BreakpointTimer final : public Timer class BreakpointTimer final : public Timer {
{
public: public:
BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) { bp = arg_bp; }
{
bp = arg_bp;
}
void Dispatch(double t, bool is_expire) override; void Dispatch(double t, bool is_expire) override;
@ -36,16 +31,14 @@ protected:
DbgBreakpoint* bp; DbgBreakpoint* bp;
}; };
void BreakpointTimer::Dispatch(double t, bool is_expire) void BreakpointTimer::Dispatch(double t, bool is_expire) {
{
if ( is_expire ) if ( is_expire )
return; return;
bp->ShouldBreak(t); bp->ShouldBreak(t);
} }
DbgBreakpoint::DbgBreakpoint() DbgBreakpoint::DbgBreakpoint() {
{
kind = BP_STMT; kind = BP_STMT;
enabled = temporary = false; enabled = temporary = false;
@ -61,14 +54,12 @@ DbgBreakpoint::DbgBreakpoint()
source_line = 0; source_line = 0;
} }
DbgBreakpoint::~DbgBreakpoint() DbgBreakpoint::~DbgBreakpoint() {
{
SetEnable(false); // clean up any active state SetEnable(false); // clean up any active state
RemoveFromGlobalMap(); RemoveFromGlobalMap();
} }
bool DbgBreakpoint::SetEnable(bool do_enable) bool DbgBreakpoint::SetEnable(bool do_enable) {
{
bool old_value = enabled; bool old_value = enabled;
enabled = do_enable; enabled = do_enable;
@ -82,23 +73,19 @@ bool DbgBreakpoint::SetEnable(bool do_enable)
return old_value; return old_value;
} }
void DbgBreakpoint::AddToGlobalMap() void DbgBreakpoint::AddToGlobalMap() {
{
// Make sure it's not there already. // Make sure it's not there already.
RemoveFromGlobalMap(); RemoveFromGlobalMap();
g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this)); g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this));
} }
void DbgBreakpoint::RemoveFromGlobalMap() void DbgBreakpoint::RemoveFromGlobalMap() {
{
std::pair<BPMapType::iterator, BPMapType::iterator> p; std::pair<BPMapType::iterator, BPMapType::iterator> p;
p = g_debugger_state.breakpoint_map.equal_range(at_stmt); p = g_debugger_state.breakpoint_map.equal_range(at_stmt);
for ( BPMapType::iterator i = p.first; i != p.second; ) for ( BPMapType::iterator i = p.first; i != p.second; ) {
{ if ( i->second == this ) {
if ( i->second == this )
{
BPMapType::iterator next = i; BPMapType::iterator next = i;
++next; ++next;
g_debugger_state.breakpoint_map.erase(i); g_debugger_state.breakpoint_map.erase(i);
@ -109,34 +96,28 @@ void DbgBreakpoint::RemoveFromGlobalMap()
} }
} }
void DbgBreakpoint::AddToStmt() void DbgBreakpoint::AddToStmt() {
{
if ( at_stmt ) if ( at_stmt )
at_stmt->IncrBPCount(); at_stmt->IncrBPCount();
} }
void DbgBreakpoint::RemoveFromStmt() void DbgBreakpoint::RemoveFromStmt() {
{
if ( at_stmt ) if ( at_stmt )
at_stmt->DecrBPCount(); at_stmt->DecrBPCount();
} }
bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str) bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str) {
{ if ( plr.type == PLR_UNKNOWN ) {
if ( plr.type == PLR_UNKNOWN )
{
debug_msg("Breakpoint specifier invalid or operation canceled.\n"); debug_msg("Breakpoint specifier invalid or operation canceled.\n");
return false; return false;
} }
if ( plr.type == PLR_FILE_AND_LINE ) if ( plr.type == PLR_FILE_AND_LINE ) {
{
kind = BP_LINE; kind = BP_LINE;
source_filename = plr.filename; source_filename = plr.filename;
source_line = plr.line; source_line = plr.line;
if ( ! plr.stmt ) if ( ! plr.stmt ) {
{
debug_msg("No statement at that line.\n"); debug_msg("No statement at that line.\n");
return false; return false;
} }
@ -147,15 +128,13 @@ bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str)
debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
} }
else if ( plr.type == PLR_FUNCTION ) else if ( plr.type == PLR_FUNCTION ) {
{
std::string loc_s(loc_str); std::string loc_s(loc_str);
kind = BP_FUNC; kind = BP_FUNC;
function_name = make_full_var_name(current_module.c_str(), loc_s.c_str()); function_name = make_full_var_name(current_module.c_str(), loc_s.c_str());
at_stmt = plr.stmt; at_stmt = plr.stmt;
const Location* loc = at_stmt->GetLocationInfo(); const Location* loc = at_stmt->GetLocationInfo();
snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), loc->filename, loc->last_line);
loc->filename, loc->last_line);
debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); debug_msg("Breakpoint %d set at %s\n", GetID(), Description());
} }
@ -165,8 +144,7 @@ bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str)
return true; return true;
} }
bool DbgBreakpoint::SetLocation(Stmt* stmt) bool DbgBreakpoint::SetLocation(Stmt* stmt) {
{
if ( ! stmt ) if ( ! stmt )
return false; return false;
@ -184,8 +162,7 @@ bool DbgBreakpoint::SetLocation(Stmt* stmt)
return true; return true;
} }
bool DbgBreakpoint::SetLocation(double t) bool DbgBreakpoint::SetLocation(double t) {
{
debug_msg("SetLocation(time) has not been debugged."); debug_msg("SetLocation(time) has not been debugged.");
return false; return false;
@ -198,15 +175,11 @@ bool DbgBreakpoint::SetLocation(double t)
return false; return false;
} }
bool DbgBreakpoint::Reset() bool DbgBreakpoint::Reset() {
{
ParseLocationRec plr; ParseLocationRec plr;
switch ( kind ) switch ( kind ) {
{ case BP_TIME: debug_msg("Time-based breakpoints not yet supported.\n"); break;
case BP_TIME:
debug_msg("Time-based breakpoints not yet supported.\n");
break;
case BP_FUNC: case BP_FUNC:
case BP_STMT: case BP_STMT:
@ -223,59 +196,48 @@ bool DbgBreakpoint::Reset()
return false; return false;
} }
bool DbgBreakpoint::SetCondition(const std::string& new_condition) bool DbgBreakpoint::SetCondition(const std::string& new_condition) {
{
condition = new_condition; condition = new_condition;
return true; return true;
} }
bool DbgBreakpoint::SetRepeatCount(int count) bool DbgBreakpoint::SetRepeatCount(int count) {
{
repeat_count = count; repeat_count = count;
return true; return true;
} }
BreakCode DbgBreakpoint::HasHit() BreakCode DbgBreakpoint::HasHit() {
{ if ( temporary ) {
if ( temporary )
{
SetEnable(false); SetEnable(false);
return BC_HIT_AND_DELETE; return BC_HIT_AND_DELETE;
} }
if ( condition.size() ) if ( condition.size() ) {
{
// TODO: ### evaluate using debugger frame too // TODO: ### evaluate using debugger frame too
auto yes = dbg_eval_expr(condition.c_str()); auto yes = dbg_eval_expr(condition.c_str());
if ( ! yes ) if ( ! yes ) {
{ debug_msg("Breakpoint condition '%s' invalid, removing condition.\n", condition.c_str());
debug_msg("Breakpoint condition '%s' invalid, removing condition.\n",
condition.c_str());
SetCondition(""); SetCondition("");
PrintHitMsg(); PrintHitMsg();
return BC_HIT; return BC_HIT;
} }
if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) {
{
PrintHitMsg(); PrintHitMsg();
debug_msg("Breakpoint condition should return an integral type"); debug_msg("Breakpoint condition should return an integral type");
return BC_HIT_AND_DELETE; return BC_HIT_AND_DELETE;
} }
yes->CoerceToInt(); yes->CoerceToInt();
if ( yes->IsZero() ) if ( yes->IsZero() ) {
{
return BC_NO_HIT; return BC_NO_HIT;
} }
} }
int repcount = GetRepeatCount(); int repcount = GetRepeatCount();
if ( repcount ) if ( repcount ) {
{ if ( ++hit_count == repcount ) {
if ( ++hit_count == repcount )
{
hit_count = 0; hit_count = 0;
PrintHitMsg(); PrintHitMsg();
return BC_HIT; return BC_HIT;
@ -288,13 +250,11 @@ BreakCode DbgBreakpoint::HasHit()
return BC_HIT; return BC_HIT;
} }
BreakCode DbgBreakpoint::ShouldBreak(Stmt* s) BreakCode DbgBreakpoint::ShouldBreak(Stmt* s) {
{
if ( ! IsEnabled() ) if ( ! IsEnabled() )
return BC_NO_HIT; return BC_NO_HIT;
switch ( kind ) switch ( kind ) {
{
case BP_STMT: case BP_STMT:
case BP_FUNC: case BP_FUNC:
if ( at_stmt != s ) if ( at_stmt != s )
@ -302,15 +262,12 @@ BreakCode DbgBreakpoint::ShouldBreak(Stmt* s)
break; break;
case BP_LINE: case BP_LINE:
assert(s->GetLocationInfo()->first_line <= source_line && assert(s->GetLocationInfo()->first_line <= source_line && s->GetLocationInfo()->last_line >= source_line);
s->GetLocationInfo()->last_line >= source_line);
break; break;
case BP_TIME: case BP_TIME: assert(false);
assert(false);
default: default: reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak");
reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak");
} }
// If we got here, that means that the breakpoint could hit, // If we got here, that means that the breakpoint could hit,
@ -323,8 +280,7 @@ BreakCode DbgBreakpoint::ShouldBreak(Stmt* s)
return code; return code;
} }
BreakCode DbgBreakpoint::ShouldBreak(double t) BreakCode DbgBreakpoint::ShouldBreak(double t) {
{
if ( kind != BP_TIME ) if ( kind != BP_TIME )
reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint"); reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint");
@ -341,14 +297,11 @@ BreakCode DbgBreakpoint::ShouldBreak(double t)
return code; return code;
} }
void DbgBreakpoint::PrintHitMsg() void DbgBreakpoint::PrintHitMsg() {
{ switch ( kind ) {
switch ( kind )
{
case BP_STMT: case BP_STMT:
case BP_FUNC: case BP_FUNC:
case BP_LINE: case BP_LINE: {
{
ODesc d; ODesc d;
Frame* f = g_frame_stack.back(); Frame* f = g_frame_stack.back();
const ScriptFunc* func = f->GetFunction(); const ScriptFunc* func = f->GetFunction();
@ -358,16 +311,13 @@ void DbgBreakpoint::PrintHitMsg()
const Location* loc = at_stmt->GetLocationInfo(); const Location* loc = at_stmt->GetLocationInfo();
debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename, debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename, loc->first_line);
loc->first_line);
} }
return; return;
case BP_TIME: case BP_TIME: assert(false);
assert(false);
default: default: reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n");
reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n");
} }
} }

View file

@ -6,27 +6,14 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
class Stmt; class Stmt;
class ParseLocationRec; class ParseLocationRec;
enum BreakCode enum BreakCode { BC_NO_HIT, BC_HIT, BC_HIT_AND_DELETE };
{ class DbgBreakpoint {
BC_NO_HIT, enum Kind { BP_STMT = 0, BP_FUNC, BP_LINE, BP_TIME };
BC_HIT,
BC_HIT_AND_DELETE
};
class DbgBreakpoint
{
enum Kind
{
BP_STMT = 0,
BP_FUNC,
BP_LINE,
BP_TIME
};
public: public:
DbgBreakpoint(); DbgBreakpoint();

View file

@ -2,20 +2,17 @@
#pragma once #pragma once
namespace zeek::detail namespace zeek::detail {
{
class Expr; class Expr;
// Automatic displays: display these at each stoppage. // Automatic displays: display these at each stoppage.
class DbgDisplay class DbgDisplay {
{
public: public:
DbgDisplay(Expr* expr_to_display); DbgDisplay(Expr* expr_to_display);
bool IsEnabled() { return enabled; } bool IsEnabled() { return enabled; }
bool SetEnable(bool do_enable) bool SetEnable(bool do_enable) {
{
bool old_value = enabled; bool old_value = enabled;
enabled = do_enable; enabled = do_enable;
return old_value; return old_value;

View file

@ -7,18 +7,11 @@
#include "zeek/Debug.h" #include "zeek/Debug.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
// Support classes // Support classes
DbgWatch::DbgWatch(zeek::Obj* var_to_watch) DbgWatch::DbgWatch(zeek::Obj* var_to_watch) { reporter->InternalError("DbgWatch unimplemented"); }
{
reporter->InternalError("DbgWatch unimplemented");
}
DbgWatch::DbgWatch(Expr* expr_to_watch) DbgWatch::DbgWatch(Expr* expr_to_watch) { reporter->InternalError("DbgWatch unimplemented"); }
{
reporter->InternalError("DbgWatch unimplemented");
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,18 +4,15 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
class Obj; class Obj;
} }
namespace zeek::detail namespace zeek::detail {
{
class Expr; class Expr;
class DbgWatch class DbgWatch {
{
public: public:
explicit DbgWatch(Obj* var_to_watch); explicit DbgWatch(Obj* var_to_watch);
explicit DbgWatch(Expr* expr_to_watch); explicit DbgWatch(Expr* expr_to_watch);

View file

@ -32,8 +32,7 @@
#include "zeek/module_util.h" #include "zeek/module_util.h"
#include "zeek/util.h" #include "zeek/util.h"
extern "C" extern "C" {
{
#include "zeek/3rdparty/setsignal.h" #include "zeek/3rdparty/setsignal.h"
} }
@ -66,11 +65,9 @@ extern YYLTYPE yylloc; // holds start line and column of token
extern int line_number; extern int line_number;
extern const char* filename; extern const char* filename;
namespace zeek::detail namespace zeek::detail {
{
DebuggerState::DebuggerState() DebuggerState::DebuggerState() {
{
next_bp_id = next_watch_id = next_display_id = 1; next_bp_id = next_watch_id = next_display_id = 1;
BreakBeforeNextStmt(false); BreakBeforeNextStmt(false);
curr_frame_idx = 0; curr_frame_idx = 0;
@ -81,13 +78,9 @@ DebuggerState::DebuggerState()
dbg_locals = new Frame(1024, /* func = */ nullptr, /* fn_args = */ nullptr); dbg_locals = new Frame(1024, /* func = */ nullptr, /* fn_args = */ nullptr);
} }
DebuggerState::~DebuggerState() DebuggerState::~DebuggerState() { Unref(dbg_locals); }
{
Unref(dbg_locals);
}
bool StmtLocMapping::StartsAfter(const StmtLocMapping* m2) bool StmtLocMapping::StartsAfter(const StmtLocMapping* m2) {
{
if ( ! m2 ) if ( ! m2 )
reporter->InternalError("Assertion failed: m2 != 0"); reporter->InternalError("Assertion failed: m2 != 0");
@ -96,8 +89,7 @@ bool StmtLocMapping::StartsAfter(const StmtLocMapping* m2)
} }
// Generic debug message output. // Generic debug message output.
int debug_msg(const char* fmt, ...) int debug_msg(const char* fmt, ...) {
{
va_list args; va_list args;
int retval; int retval;
@ -110,8 +102,7 @@ int debug_msg(const char* fmt, ...)
// Trace message output // Trace message output
FILE* TraceState::SetTraceFile(const char* trace_filename) FILE* TraceState::SetTraceFile(const char* trace_filename) {
{
FILE* newfile; FILE* newfile;
if ( util::streq(trace_filename, "-") ) if ( util::streq(trace_filename, "-") )
@ -120,12 +111,10 @@ FILE* TraceState::SetTraceFile(const char* trace_filename)
newfile = fopen(trace_filename, "w"); newfile = fopen(trace_filename, "w");
FILE* oldfile = trace_file; FILE* oldfile = trace_file;
if ( newfile ) if ( newfile ) {
{
trace_file = newfile; trace_file = newfile;
} }
else else {
{
fprintf(stderr, "Unable to open trace file %s\n", trace_filename); fprintf(stderr, "Unable to open trace file %s\n", trace_filename);
trace_file = nullptr; trace_file = nullptr;
} }
@ -133,20 +122,17 @@ FILE* TraceState::SetTraceFile(const char* trace_filename)
return oldfile; return oldfile;
} }
void TraceState::TraceOn() void TraceState::TraceOn() {
{
fprintf(stderr, "Execution tracing ON.\n"); fprintf(stderr, "Execution tracing ON.\n");
dbgtrace = true; dbgtrace = true;
} }
void TraceState::TraceOff() void TraceState::TraceOff() {
{
fprintf(stderr, "Execution tracing OFF.\n"); fprintf(stderr, "Execution tracing OFF.\n");
dbgtrace = false; dbgtrace = false;
} }
int TraceState::LogTrace(const char* fmt, ...) int TraceState::LogTrace(const char* fmt, ...) {
{
va_list args; va_list args;
int retval; int retval;
@ -159,21 +145,18 @@ int TraceState::LogTrace(const char* fmt, ...)
Location loc; Location loc;
loc.filename = nullptr; loc.filename = nullptr;
if ( g_frame_stack.size() > 0 && g_frame_stack.back() ) if ( g_frame_stack.size() > 0 && g_frame_stack.back() ) {
{
stmt = g_frame_stack.back()->GetNextStmt(); stmt = g_frame_stack.back()->GetNextStmt();
if ( stmt ) if ( stmt )
loc = *stmt->GetLocationInfo(); loc = *stmt->GetLocationInfo();
else else {
{
const ScriptFunc* f = g_frame_stack.back()->GetFunction(); const ScriptFunc* f = g_frame_stack.back()->GetFunction();
if ( f ) if ( f )
loc = *f->GetLocationInfo(); loc = *f->GetLocationInfo();
} }
} }
if ( ! loc.filename ) if ( ! loc.filename ) {
{
loc.filename = util::copy_string("<no filename>"); loc.filename = util::copy_string("<no filename>");
loc.last_line = 0; loc.last_line = 0;
} }
@ -193,17 +176,14 @@ int TraceState::LogTrace(const char* fmt, ...)
} }
// Helper functions. // Helper functions.
void get_first_statement(Stmt* list, Stmt*& first, Location& loc) void get_first_statement(Stmt* list, Stmt*& first, Location& loc) {
{ if ( ! list ) {
if ( ! list )
{
first = nullptr; first = nullptr;
return; return;
} }
first = list; first = list;
while ( first->Tag() == STMT_LIST ) while ( first->Tag() == STMT_LIST ) {
{
if ( first->AsStmtList()->Stmts()[0] ) if ( first->AsStmtList()->Stmts()[0] )
first = first->AsStmtList()->Stmts()[0].get(); first = first->AsStmtList()->Stmts()[0].get();
else else
@ -214,27 +194,23 @@ void get_first_statement(Stmt* list, Stmt*& first, Location& loc)
} }
static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationRec& plr, static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationRec& plr,
const string& s) const string& s) { // function name
{ // function name
const auto& id = lookup_ID(s.c_str(), current_module.c_str()); const auto& id = lookup_ID(s.c_str(), current_module.c_str());
if ( ! id ) if ( ! id ) {
{
string fullname = make_full_var_name(current_module.c_str(), s.c_str()); string fullname = make_full_var_name(current_module.c_str(), s.c_str());
debug_msg("Function %s not defined.\n", fullname.c_str()); debug_msg("Function %s not defined.\n", fullname.c_str());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
if ( ! id->GetType()->AsFuncType() ) if ( ! id->GetType()->AsFuncType() ) {
{
debug_msg("Function %s not declared.\n", id->Name()); debug_msg("Function %s not declared.\n", id->Name());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
if ( ! id->HasVal() ) if ( ! id->HasVal() ) {
{
debug_msg("Function %s declared but not defined.\n", id->Name()); debug_msg("Function %s declared but not defined.\n", id->Name());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
@ -243,8 +219,7 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
const Func* func = id->GetVal()->AsFunc(); const Func* func = id->GetVal()->AsFunc();
const vector<Func::Body>& bodies = func->GetBodies(); const vector<Func::Body>& bodies = func->GetBodies();
if ( bodies.size() == 0 ) if ( bodies.size() == 0 ) {
{
debug_msg("Function %s is a built-in function\n", id->Name()); debug_msg("Function %s is a built-in function\n", id->Name());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
@ -254,14 +229,12 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
if ( bodies.size() == 1 ) if ( bodies.size() == 1 )
body = bodies[0].stmts.get(); body = bodies[0].stmts.get();
else else {
{ while ( true ) {
while ( true ) debug_msg(
{ "There are multiple definitions of that event handler.\n"
debug_msg("There are multiple definitions of that event handler.\n"
"Please choose one of the following options:\n"); "Please choose one of the following options:\n");
for ( unsigned int i = 0; i < bodies.size(); ++i ) for ( unsigned int i = 0; i < bodies.size(); ++i ) {
{
Stmt* first; Stmt* first;
Location stmt_loc; Location stmt_loc;
get_first_statement(bodies[i].stmts.get(), first, stmt_loc); get_first_statement(bodies[i].stmts.get(), first, stmt_loc);
@ -273,8 +246,7 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
debug_msg("Enter your choice: "); debug_msg("Enter your choice: ");
char charinput[256]; char charinput[256];
if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) {
{
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
@ -287,15 +259,13 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
if ( input == "a" ) if ( input == "a" )
break; break;
if ( input == "n" ) if ( input == "n" ) {
{
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return; return;
} }
int option = atoi(input.c_str()); int option = atoi(input.c_str());
if ( option > 0 && option <= (int)bodies.size() ) if ( option > 0 && option <= (int)bodies.size() ) {
{
body = bodies[option - 1].stmts.get(); body = bodies[option - 1].stmts.get();
break; break;
} }
@ -308,24 +278,20 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
Stmt* first; Stmt* first;
Location stmt_loc; Location stmt_loc;
if ( body ) if ( body ) {
{
get_first_statement(body, first, stmt_loc); get_first_statement(body, first, stmt_loc);
if ( first ) if ( first ) {
{
plr.stmt = first; plr.stmt = first;
plr.filename = stmt_loc.filename; plr.filename = stmt_loc.filename;
plr.line = stmt_loc.last_line; plr.line = stmt_loc.last_line;
} }
} }
else else {
{
result.pop_back(); result.pop_back();
ParseLocationRec result_plr; ParseLocationRec result_plr;
for ( const auto& body : bodies ) for ( const auto& body : bodies ) {
{
get_first_statement(body.stmts.get(), first, stmt_loc); get_first_statement(body.stmts.get(), first, stmt_loc);
if ( ! first ) if ( ! first )
continue; continue;
@ -339,8 +305,7 @@ static void parse_function_name(vector<ParseLocationRec>& result, ParseLocationR
} }
} }
vector<ParseLocationRec> parse_location_string(const string& s) vector<ParseLocationRec> parse_location_string(const string& s) {
{
vector<ParseLocationRec> result; vector<ParseLocationRec> result;
result.emplace_back(); result.emplace_back();
ParseLocationRec& plr = result[0]; ParseLocationRec& plr = result[0];
@ -350,21 +315,18 @@ vector<ParseLocationRec> parse_location_string(const string& s)
// up the line number to find the corresponding statement. // up the line number to find the corresponding statement.
std::string loc_filename; std::string loc_filename;
if ( sscanf(s.c_str(), "%d", &plr.line) ) if ( sscanf(s.c_str(), "%d", &plr.line) ) { // just a line number (implicitly referring to the current file)
{ // just a line number (implicitly referring to the current file)
loc_filename = g_debugger_state.last_loc.filename; loc_filename = g_debugger_state.last_loc.filename;
plr.type = PLR_FILE_AND_LINE; plr.type = PLR_FILE_AND_LINE;
} }
else else {
{
string::size_type pos_colon = s.find(':'); string::size_type pos_colon = s.find(':');
string::size_type pos_dblcolon = s.find("::"); string::size_type pos_dblcolon = s.find("::");
if ( pos_colon == string::npos || pos_dblcolon != string::npos ) if ( pos_colon == string::npos || pos_dblcolon != string::npos )
parse_function_name(result, plr, s); parse_function_name(result, plr, s);
else else { // file:line
{ // file:line
string policy_filename = s.substr(0, pos_colon); string policy_filename = s.substr(0, pos_colon);
string line_string = s.substr(pos_colon + 1, s.length() - pos_colon); string line_string = s.substr(pos_colon + 1, s.length() - pos_colon);
@ -373,8 +335,7 @@ vector<ParseLocationRec> parse_location_string(const string& s)
string path(util::find_script_file(policy_filename, util::zeek_path())); string path(util::find_script_file(policy_filename, util::zeek_path()));
if ( path.empty() ) if ( path.empty() ) {
{
debug_msg("No such policy file: %s.\n", policy_filename.c_str()); debug_msg("No such policy file: %s.\n", policy_filename.c_str());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return result; return result;
@ -385,30 +346,25 @@ vector<ParseLocationRec> parse_location_string(const string& s)
} }
} }
if ( plr.type == PLR_FILE_AND_LINE ) if ( plr.type == PLR_FILE_AND_LINE ) {
{
auto iter = g_dbgfilemaps.find(loc_filename); auto iter = g_dbgfilemaps.find(loc_filename);
if ( iter == g_dbgfilemaps.end() ) if ( iter == g_dbgfilemaps.end() )
reporter->InternalError("Policy file %s should have been loaded\n", reporter->InternalError("Policy file %s should have been loaded\n", loc_filename.data());
loc_filename.data());
if ( plr.line > how_many_lines_in(loc_filename.data()) ) if ( plr.line > how_many_lines_in(loc_filename.data()) ) {
{
debug_msg("No line %d in %s.\n", plr.line, loc_filename.data()); debug_msg("No line %d in %s.\n", plr.line, loc_filename.data());
plr.type = PLR_UNKNOWN; plr.type = PLR_UNKNOWN;
return result; return result;
} }
StmtLocMapping* hit = nullptr; StmtLocMapping* hit = nullptr;
for ( const auto entry : *(iter->second) ) for ( const auto entry : *(iter->second) ) {
{
plr.filename = entry->Loc().filename; plr.filename = entry->Loc().filename;
if ( entry->Loc().first_line > plr.line ) if ( entry->Loc().first_line > plr.line )
break; break;
if ( plr.line >= entry->Loc().first_line && plr.line <= entry->Loc().last_line ) if ( plr.line >= entry->Loc().first_line && plr.line <= entry->Loc().last_line ) {
{
hit = entry; hit = entry;
break; break;
} }
@ -431,8 +387,7 @@ static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args);
void using_history(void); void using_history(void);
static bool init_readline() static bool init_readline() {
{
// ### Set up custom completion. // ### Set up custom completion.
rl_outstream = stderr; rl_outstream = stderr;
@ -443,14 +398,12 @@ static bool init_readline()
#endif #endif
void break_signal(int) void break_signal(int) {
{
g_debugger_state.BreakBeforeNextStmt(true); g_debugger_state.BreakBeforeNextStmt(true);
g_debugger_state.BreakFromSignal(true); g_debugger_state.BreakFromSignal(true);
} }
int dbg_init_debugger(const char* cmdfile) int dbg_init_debugger(const char* cmdfile) {
{
if ( ! g_policy_debug ) if ( ! g_policy_debug )
return 0; // probably shouldn't have been called return 0; // probably shouldn't have been called
@ -474,8 +427,7 @@ int dbg_init_debugger(const char* cmdfile)
return 1; return 1;
} }
int dbg_shutdown_debugger() int dbg_shutdown_debugger() {
{
// ### TODO: Remove signal handlers // ### TODO: Remove signal handlers
return 1; return 1;
} }
@ -488,23 +440,19 @@ int dbg_shutdown_debugger()
// Parse the string into individual tokens, similarly to how shell // Parse the string into individual tokens, similarly to how shell
// would do it. // would do it.
void tokenize(const char* cstr, string& operation, vector<string>& arguments) void tokenize(const char* cstr, string& operation, vector<string>& arguments) {
{
int num_tokens = 0; int num_tokens = 0;
char delim = '\0'; char delim = '\0';
const string str(cstr); const string str(cstr);
for ( int i = 0; i < (signed int)str.length(); ++i ) for ( int i = 0; i < (signed int)str.length(); ++i ) {
{
while ( isspace((unsigned char)str[i]) ) while ( isspace((unsigned char)str[i]) )
++i; ++i;
int start = i; int start = i;
for ( ; str[i]; ++i ) for ( ; str[i]; ++i ) {
{ if ( str[i] == '\\' ) {
if ( str[i] == '\\' )
{
if ( i < (signed int)str.length() ) if ( i < (signed int)str.length() )
++i; ++i;
} }
@ -515,8 +463,7 @@ void tokenize(const char* cstr, string& operation, vector<string>& arguments)
else if ( ! delim && (str[i] == '\'' || str[i] == '"') ) else if ( ! delim && (str[i] == '\'' || str[i] == '"') )
delim = str[i]; delim = str[i];
else if ( delim && str[i] == delim ) else if ( delim && str[i] == delim ) {
{
delim = '\0'; delim = '\0';
++i; ++i;
break; break;
@ -538,8 +485,7 @@ void tokenize(const char* cstr, string& operation, vector<string>& arguments)
} }
// Given a command string, parse it and send the command to be dispatched. // Given a command string, parse it and send the command to be dispatched.
int dbg_execute_command(const char* cmd) int dbg_execute_command(const char* cmd) {
{
bool matched_history = false; bool matched_history = false;
if ( ! cmd ) if ( ! cmd )
@ -549,16 +495,14 @@ int dbg_execute_command(const char* cmd)
{ {
#ifdef HAVE_READLINE #ifdef HAVE_READLINE
int i; int i;
for ( i = history_length; i >= 1; --i ) for ( i = history_length; i >= 1; --i ) {
{
HIST_ENTRY* entry = history_get(i); HIST_ENTRY* entry = history_get(i);
if ( ! entry ) if ( ! entry )
return 0; return 0;
const DebugCmdInfo* info = (const DebugCmdInfo*)entry->data; const DebugCmdInfo* info = (const DebugCmdInfo*)entry->data;
if ( info && info->Repeatable() ) if ( info && info->Repeatable() ) {
{
cmd = entry->line; cmd = entry->line;
matched_history = true; matched_history = true;
break; break;
@ -583,14 +527,12 @@ int dbg_execute_command(const char* cmd)
auto matching_cmds = matching_cmds_buf.get(); auto matching_cmds = matching_cmds_buf.get();
int num_matches = find_all_matching_cmds(opstring, matching_cmds); int num_matches = find_all_matching_cmds(opstring, matching_cmds);
if ( ! num_matches ) if ( ! num_matches ) {
{
debug_msg("No Matching command for '%s'.\n", opstring.c_str()); debug_msg("No Matching command for '%s'.\n", opstring.c_str());
return 0; return 0;
} }
if ( num_matches > 1 ) if ( num_matches > 1 ) {
{
debug_msg("Ambiguous command; could be\n"); debug_msg("Ambiguous command; could be\n");
for ( int i = 0; i < num_debug_cmds(); ++i ) for ( int i = 0; i < num_debug_cmds(); ++i )
@ -603,16 +545,14 @@ int dbg_execute_command(const char* cmd)
// Matched exactly one command: find out which one. // Matched exactly one command: find out which one.
DebugCmd cmd_code = dcInvalid; DebugCmd cmd_code = dcInvalid;
for ( int i = 0; i < num_debug_cmds(); ++i ) for ( int i = 0; i < num_debug_cmds(); ++i )
if ( matching_cmds[i] ) if ( matching_cmds[i] ) {
{
cmd_code = (DebugCmd)i; cmd_code = (DebugCmd)i;
break; break;
} }
#ifdef HAVE_READLINE #ifdef HAVE_READLINE
// Insert command into history. // Insert command into history.
if ( ! matched_history && cmd && *cmd ) if ( ! matched_history && cmd && *cmd ) {
{
/* The prototype for add_history(), at least under MacOS, /* The prototype for add_history(), at least under MacOS,
* has it taking a char* rather than a const char*. * has it taking a char* rather than a const char*.
* But documentation at * But documentation at
@ -644,17 +584,11 @@ int dbg_execute_command(const char* cmd)
} }
// Call the appropriate function for the command. // Call the appropriate function for the command.
static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args) static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args) {
{ switch ( cmd_code ) {
switch ( cmd_code ) case dcHelp: dbg_cmd_help(cmd_code, args); break;
{
case dcHelp:
dbg_cmd_help(cmd_code, args);
break;
case dcQuit: case dcQuit: debug_msg("Program Terminating\n"); exit(0);
debug_msg("Program Terminating\n");
exit(0);
case dcNext: case dcNext:
g_frame_stack.back()->BreakBeforeNextStmt(true); g_frame_stack.back()->BreakBeforeNextStmt(true);
@ -678,50 +612,36 @@ static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args)
g_debugger_state.BreakBeforeNextStmt(false); g_debugger_state.BreakBeforeNextStmt(false);
break; break;
case dcBreak: case dcBreak: dbg_cmd_break(cmd_code, args); break;
dbg_cmd_break(cmd_code, args);
break;
case dcBreakCondition: case dcBreakCondition: dbg_cmd_break_condition(cmd_code, args); break;
dbg_cmd_break_condition(cmd_code, args);
break;
case dcDeleteBreak: case dcDeleteBreak:
case dcClearBreak: case dcClearBreak:
case dcDisableBreak: case dcDisableBreak:
case dcEnableBreak: case dcEnableBreak:
case dcIgnoreBreak: case dcIgnoreBreak: dbg_cmd_break_set_state(cmd_code, args); break;
dbg_cmd_break_set_state(cmd_code, args);
break;
case dcPrint: case dcPrint: dbg_cmd_print(cmd_code, args); break;
dbg_cmd_print(cmd_code, args);
break;
case dcBacktrace: case dcBacktrace: return dbg_cmd_backtrace(cmd_code, args);
return dbg_cmd_backtrace(cmd_code, args);
case dcFrame: case dcFrame:
case dcUp: case dcUp:
case dcDown: case dcDown: return dbg_cmd_frame(cmd_code, args);
return dbg_cmd_frame(cmd_code, args);
case dcInfo: case dcInfo: return dbg_cmd_info(cmd_code, args);
return dbg_cmd_info(cmd_code, args);
case dcList: case dcList: return dbg_cmd_list(cmd_code, args);
return dbg_cmd_list(cmd_code, args);
case dcDisplay: case dcDisplay:
case dcUndisplay: case dcUndisplay: debug_msg("Command not yet implemented.\n"); break;
debug_msg("Command not yet implemented.\n");
break;
case dcTrace: case dcTrace: return dbg_cmd_trace(cmd_code, args);
return dbg_cmd_trace(cmd_code, args);
default: default:
debug_msg("INTERNAL ERROR: " debug_msg(
"INTERNAL ERROR: "
"Got an unknown debugger command in DbgDispatchCmd: %d\n", "Got an unknown debugger command in DbgDispatchCmd: %d\n",
cmd_code); cmd_code);
return 0; return 0;
@ -730,8 +650,7 @@ static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector<string>& args)
return 0; return 0;
} }
static char* get_prompt(bool reset_counter = false) static char* get_prompt(bool reset_counter = false) {
{
static char prompt[512]; static char prompt[512];
static int counter = 0; static int counter = 0;
@ -743,8 +662,7 @@ static char* get_prompt(bool reset_counter = false)
return prompt; return prompt;
} }
string get_context_description(const Stmt* stmt, const Frame* frame) string get_context_description(const Stmt* stmt, const Frame* frame) {
{
ODesc d; ODesc d;
const ScriptFunc* func = frame ? frame->GetFunction() : nullptr; const ScriptFunc* func = frame ? frame->GetFunction() : nullptr;
@ -756,8 +674,7 @@ string get_context_description(const Stmt* stmt, const Frame* frame)
Location loc; Location loc;
if ( stmt ) if ( stmt )
loc = *stmt->GetLocationInfo(); loc = *stmt->GetLocationInfo();
else else {
{
loc.filename = util::copy_string("<no filename>"); loc.filename = util::copy_string("<no filename>");
loc.last_line = 0; loc.last_line = 0;
} }
@ -771,13 +688,11 @@ string get_context_description(const Stmt* stmt, const Frame* frame)
return retval; return retval;
} }
int dbg_handle_debug_input() int dbg_handle_debug_input() {
{
static char* input_line = nullptr; static char* input_line = nullptr;
int status = 0; int status = 0;
if ( g_debugger_state.BreakFromSignal() ) if ( g_debugger_state.BreakFromSignal() ) {
{
debug_msg("Program received signal SIGINT: entering debugger\n"); debug_msg("Program received signal SIGINT: entering debugger\n");
g_debugger_state.BreakFromSignal(false); g_debugger_state.BreakFromSignal(false);
@ -796,8 +711,7 @@ int dbg_handle_debug_input()
const Location loc = *stmt->GetLocationInfo(); const Location loc = *stmt->GetLocationInfo();
if ( ! step_or_next_pending || g_frame_stack.back() != last_frame ) if ( ! step_or_next_pending || g_frame_stack.back() != last_frame ) {
{
string context = get_context_description(stmt, g_frame_stack.back()); string context = get_context_description(stmt, g_frame_stack.back());
debug_msg("%s\n", context.c_str()); debug_msg("%s\n", context.c_str());
} }
@ -807,8 +721,7 @@ int dbg_handle_debug_input()
PrintLines(loc.filename, loc.first_line, loc.last_line - loc.first_line + 1, true); PrintLines(loc.filename, loc.first_line, loc.last_line - loc.first_line + 1, true);
g_debugger_state.last_loc = loc; g_debugger_state.last_loc = loc;
do do {
{
// readline returns a pointer to a buffer it allocates; it's // readline returns a pointer to a buffer it allocates; it's
// freed at the bottom. // freed at the bottom.
#ifdef HAVE_READLINE #ifdef HAVE_READLINE
@ -830,8 +743,7 @@ int dbg_handle_debug_input()
status = dbg_execute_command(input_line); status = dbg_execute_command(input_line);
if ( input_line ) if ( input_line ) {
{
free(input_line); // this was malloc'ed free(input_line); // this was malloc'ed
input_line = nullptr; input_line = nullptr;
} }
@ -850,13 +762,11 @@ int dbg_handle_debug_input()
} }
// Return true to continue execution, false to abort. // Return true to continue execution, false to abort.
bool pre_execute_stmt(Stmt* stmt, Frame* f) bool pre_execute_stmt(Stmt* stmt, Frame* f) {
{
if ( ! g_policy_debug || stmt->Tag() == STMT_LIST || stmt->Tag() == STMT_NULL ) if ( ! g_policy_debug || stmt->Tag() == STMT_LIST || stmt->Tag() == STMT_NULL )
return true; return true;
if ( g_trace_state.DoTrace() ) if ( g_trace_state.DoTrace() ) {
{
ODesc d; ODesc d;
stmt->Describe(&d); stmt->Describe(&d);
@ -874,8 +784,7 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
bool should_break = false; bool should_break = false;
if ( g_debugger_state.BreakBeforeNextStmt() || f->BreakBeforeNextStmt() ) if ( g_debugger_state.BreakBeforeNextStmt() || f->BreakBeforeNextStmt() ) {
{
if ( g_debugger_state.BreakBeforeNextStmt() ) if ( g_debugger_state.BreakBeforeNextStmt() )
g_debugger_state.BreakBeforeNextStmt(false); g_debugger_state.BreakBeforeNextStmt(false);
@ -885,8 +794,7 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
should_break = true; should_break = true;
} }
if ( stmt->BPCount() ) if ( stmt->BPCount() ) {
{
pair<BPMapType::iterator, BPMapType::iterator> p; pair<BPMapType::iterator, BPMapType::iterator> p;
p = g_debugger_state.breakpoint_map.equal_range(stmt); p = g_debugger_state.breakpoint_map.equal_range(stmt);
@ -894,8 +802,7 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
if ( p.first == p.second ) if ( p.first == p.second )
reporter->InternalError("Breakpoint count nonzero, but no matching breakpoints"); reporter->InternalError("Breakpoint count nonzero, but no matching breakpoints");
for ( BPMapType::iterator i = p.first; i != p.second; ++i ) for ( BPMapType::iterator i = p.first; i != p.second; ++i ) {
{
int break_code = i->second->ShouldBreak(stmt); int break_code = i->second->ShouldBreak(stmt);
if ( break_code == 2 ) // ### 2? if ( break_code == 2 ) // ### 2?
{ {
@ -913,8 +820,7 @@ bool pre_execute_stmt(Stmt* stmt, Frame* f)
return true; return true;
} }
bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow) bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow) {
{
// Handle the case where someone issues a "next" debugger command, // Handle the case where someone issues a "next" debugger command,
// but we're at a return statement, so the next statement is in // but we're at a return statement, so the next statement is in
// some other function. // some other function.
@ -922,10 +828,8 @@ bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow)
g_debugger_state.BreakBeforeNextStmt(true); g_debugger_state.BreakBeforeNextStmt(true);
// Handle "finish" commands. // Handle "finish" commands.
if ( *flow == FLOW_RETURN && f->BreakOnReturn() ) if ( *flow == FLOW_RETURN && f->BreakOnReturn() ) {
{ if ( result ) {
if ( result )
{
ODesc d; ODesc d;
result->Describe(&d); result->Describe(&d);
debug_msg("Return Value: '%s'\n", d.Description()); debug_msg("Return Value: '%s'\n", d.Description());
@ -940,16 +844,14 @@ bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow)
return true; return true;
} }
ValPtr dbg_eval_expr(const char* expr) ValPtr dbg_eval_expr(const char* expr) {
{
// Push the current frame's associated scope. // Push the current frame's associated scope.
// Note: g_debugger_state.curr_frame_idx is the user-visible number, // Note: g_debugger_state.curr_frame_idx is the user-visible number,
// while the array index goes in the opposite direction // while the array index goes in the opposite direction
int frame_idx = (g_frame_stack.size() - 1) - g_debugger_state.curr_frame_idx; int frame_idx = (g_frame_stack.size() - 1) - g_debugger_state.curr_frame_idx;
if ( ! (frame_idx >= 0 && (unsigned)frame_idx < g_frame_stack.size()) ) if ( ! (frame_idx >= 0 && (unsigned)frame_idx < g_frame_stack.size()) )
reporter->InternalError( reporter->InternalError("Assertion failed: frame_idx >= 0 && (unsigned) frame_idx < g_frame_stack.size()");
"Assertion failed: frame_idx >= 0 && (unsigned) frame_idx < g_frame_stack.size()");
Frame* frame = g_frame_stack[frame_idx]; Frame* frame = g_frame_stack[frame_idx];
if ( ! (frame) ) if ( ! (frame) )
@ -973,15 +875,13 @@ ValPtr dbg_eval_expr(const char* expr)
// Parse the thing into an expr. // Parse the thing into an expr.
ValPtr result; ValPtr result;
if ( yyparse() ) if ( yyparse() ) {
{
if ( g_curr_debug_error ) if ( g_curr_debug_error )
debug_msg("Parsing expression '%s' failed: %s\n", expr, g_curr_debug_error); debug_msg("Parsing expression '%s' failed: %s\n", expr, g_curr_debug_error);
else else
debug_msg("Parsing expression '%s' failed\n", expr); debug_msg("Parsing expression '%s' failed\n", expr);
if ( g_curr_debug_expr ) if ( g_curr_debug_expr ) {
{
delete g_curr_debug_expr; delete g_curr_debug_expr;
g_curr_debug_expr = nullptr; g_curr_debug_expr = nullptr;
} }

View file

@ -11,17 +11,16 @@
#include "zeek/StmtEnums.h" #include "zeek/StmtEnums.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
class Val; class Val;
template <class T> class IntrusivePtr; template<class T>
class IntrusivePtr;
using ValPtr = zeek::IntrusivePtr<Val>; using ValPtr = zeek::IntrusivePtr<Val>;
extern std::string current_module; extern std::string current_module;
namespace detail namespace detail {
{
class Frame; class Frame;
class Stmt; class Stmt;
@ -30,14 +29,8 @@ class DbgWatch;
class DbgDisplay; class DbgDisplay;
// This needs to be defined before we do the includes that come after it. // This needs to be defined before we do the includes that come after it.
enum ParseLocationRecType enum ParseLocationRecType { PLR_UNKNOWN, PLR_FILE_AND_LINE, PLR_FUNCTION };
{ class ParseLocationRec {
PLR_UNKNOWN,
PLR_FILE_AND_LINE,
PLR_FUNCTION
};
class ParseLocationRec
{
public: public:
ParseLocationRecType type; ParseLocationRecType type;
int32_t line; int32_t line;
@ -51,11 +44,9 @@ using Filemap = std::deque<StmtLocMapping*>; // mapping for a single file
using BPIDMapType = std::map<int, DbgBreakpoint*>; using BPIDMapType = std::map<int, DbgBreakpoint*>;
using BPMapType = std::multimap<const Stmt*, DbgBreakpoint*>; using BPMapType = std::multimap<const Stmt*, DbgBreakpoint*>;
class TraceState class TraceState {
{
public: public:
TraceState() TraceState() {
{
dbgtrace = false; dbgtrace = false;
trace_file = stderr; trace_file = stderr;
} }
@ -77,8 +68,7 @@ protected:
extern TraceState g_trace_state; extern TraceState g_trace_state;
class DebuggerState class DebuggerState {
{
public: public:
DebuggerState(); DebuggerState();
~DebuggerState(); ~DebuggerState();
@ -121,12 +111,10 @@ private:
// Source line -> statement mapping. // Source line -> statement mapping.
// (obj -> source line mapping available in object itself) // (obj -> source line mapping available in object itself)
class StmtLocMapping class StmtLocMapping {
{
public: public:
StmtLocMapping() {} StmtLocMapping() {}
StmtLocMapping(const Location* l, Stmt* s) StmtLocMapping(const Location* l, Stmt* s) {
{
loc = *l; loc = *l;
stmt = s; stmt = s;
} }

View file

@ -26,29 +26,22 @@
using namespace std; using namespace std;
namespace zeek::detail namespace zeek::detail {
{
DebugCmdInfoQueue g_DebugCmdInfos; DebugCmdInfoQueue g_DebugCmdInfos;
// //
// Helper routines // Helper routines
// //
static bool string_is_regex(const string& s) static bool string_is_regex(const string& s) { return strpbrk(s.data(), "?*\\+"); }
{
return strpbrk(s.data(), "?*\\+");
}
static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& matches, static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& matches, bool func_only = false) {
bool func_only = false)
{
if ( util::streq(orig_regex.c_str(), "") ) if ( util::streq(orig_regex.c_str(), "") )
return; return;
string regex = "^"; string regex = "^";
int len = orig_regex.length(); int len = orig_regex.length();
for ( int i = 0; i < len; ++i ) for ( int i = 0; i < len; ++i ) {
{
if ( orig_regex[i] == '*' ) if ( orig_regex[i] == '*' )
regex.push_back('.'); regex.push_back('.');
regex.push_back(orig_regex[i]); regex.push_back(orig_regex[i]);
@ -56,8 +49,7 @@ static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& m
regex.push_back('$'); regex.push_back('$');
regex_t re; regex_t re;
if ( regcomp(&re, regex.c_str(), REG_EXTENDED | REG_NOSUB) ) if ( regcomp(&re, regex.c_str(), REG_EXTENDED | REG_NOSUB) ) {
{
debug_msg("Invalid regular expression: %s\n", regex.c_str()); debug_msg("Invalid regular expression: %s\n", regex.c_str());
return; return;
} }
@ -66,8 +58,7 @@ static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& m
const auto& syms = global->Vars(); const auto& syms = global->Vars();
ID* nextid; ID* nextid;
for ( const auto& sym : syms ) for ( const auto& sym : syms ) {
{
ID* nextid = sym.second.get(); ID* nextid = sym.second.get();
if ( ! func_only || nextid->GetType()->Tag() == TYPE_FUNC ) if ( ! func_only || nextid->GetType()->Tag() == TYPE_FUNC )
if ( ! regexec(&re, nextid->Name(), 0, 0, 0) ) if ( ! regexec(&re, nextid->Name(), 0, 0, 0) )
@ -75,16 +66,13 @@ static void lookup_global_symbols_regex(const string& orig_regex, vector<ID*>& m
} }
} }
static void choose_global_symbols_regex(const string& regex, vector<ID*>& choices, static void choose_global_symbols_regex(const string& regex, vector<ID*>& choices, bool func_only = false) {
bool func_only = false)
{
lookup_global_symbols_regex(regex, choices, func_only); lookup_global_symbols_regex(regex, choices, func_only);
if ( choices.size() <= 1 ) if ( choices.size() <= 1 )
return; return;
while ( true ) while ( true ) {
{
debug_msg("There were multiple matches, please choose:\n"); debug_msg("There were multiple matches, please choose:\n");
for ( size_t i = 0; i < choices.size(); i++ ) for ( size_t i = 0; i < choices.size(); i++ )
@ -95,8 +83,7 @@ static void choose_global_symbols_regex(const string& regex, vector<ID*>& choice
debug_msg("Enter your choice: "); debug_msg("Enter your choice: ");
char charinput[256]; char charinput[256];
if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) {
{
choices.clear(); choices.clear();
return; return;
} }
@ -107,14 +94,12 @@ static void choose_global_symbols_regex(const string& regex, vector<ID*>& choice
if ( input == "a" ) if ( input == "a" )
return; return;
if ( input == "n" ) if ( input == "n" ) {
{
choices.clear(); choices.clear();
return; return;
} }
int option = atoi(input.c_str()); int option = atoi(input.c_str());
if ( option > 0 && option <= (int)choices.size() ) if ( option > 0 && option <= (int)choices.size() ) {
{
ID* choice = choices[option - 1]; ID* choice = choices[option - 1];
choices.clear(); choices.clear();
choices.push_back(choice); choices.push_back(choice);
@ -127,19 +112,16 @@ static void choose_global_symbols_regex(const string& regex, vector<ID*>& choice
// DebugCmdInfo implementation // DebugCmdInfo implementation
// //
DebugCmdInfo::DebugCmdInfo(const DebugCmdInfo& info) : cmd(info.cmd), helpstring(nullptr) DebugCmdInfo::DebugCmdInfo(const DebugCmdInfo& info) : cmd(info.cmd), helpstring(nullptr) {
{
num_names = info.num_names; num_names = info.num_names;
names = info.names; names = info.names;
resume_execution = info.resume_execution; resume_execution = info.resume_execution;
repeatable = info.repeatable; repeatable = info.repeatable;
} }
DebugCmdInfo::DebugCmdInfo(DebugCmd arg_cmd, const char* const* arg_names, int arg_num_names, DebugCmdInfo::DebugCmdInfo(DebugCmd arg_cmd, const char* const* arg_names, int arg_num_names, bool arg_resume_execution,
bool arg_resume_execution, const char* const arg_helpstring, const char* const arg_helpstring, bool arg_repeatable)
bool arg_repeatable) : cmd(arg_cmd), helpstring(arg_helpstring) {
: cmd(arg_cmd), helpstring(arg_helpstring)
{
num_names = arg_num_names; num_names = arg_num_names;
resume_execution = arg_resume_execution; resume_execution = arg_resume_execution;
repeatable = arg_repeatable; repeatable = arg_repeatable;
@ -148,34 +130,29 @@ DebugCmdInfo::DebugCmdInfo(DebugCmd arg_cmd, const char* const* arg_names, int a
names.push_back(arg_names[i]); names.push_back(arg_names[i]);
} }
const DebugCmdInfo* get_debug_cmd_info(DebugCmd cmd) const DebugCmdInfo* get_debug_cmd_info(DebugCmd cmd) {
{
if ( (int)cmd < g_DebugCmdInfos.size() ) if ( (int)cmd < g_DebugCmdInfos.size() )
return g_DebugCmdInfos[(int)cmd]; return g_DebugCmdInfos[(int)cmd];
else else
return nullptr; return nullptr;
} }
int find_all_matching_cmds(const string& prefix, const char* array_of_matches[]) int find_all_matching_cmds(const string& prefix, const char* array_of_matches[]) {
{
// Trivial implementation for now (### use hashing later). // Trivial implementation for now (### use hashing later).
unsigned int arglen = prefix.length(); unsigned int arglen = prefix.length();
int matches = 0; int matches = 0;
for ( int i = 0; i < num_debug_cmds(); ++i ) for ( int i = 0; i < num_debug_cmds(); ++i ) {
{
array_of_matches[g_DebugCmdInfos[i]->Cmd()] = nullptr; array_of_matches[g_DebugCmdInfos[i]->Cmd()] = nullptr;
for ( int j = 0; j < g_DebugCmdInfos[i]->NumNames(); ++j ) for ( int j = 0; j < g_DebugCmdInfos[i]->NumNames(); ++j ) {
{
const char* curr_name = g_DebugCmdInfos[i]->Names()[j]; const char* curr_name = g_DebugCmdInfos[i]->Names()[j];
if ( strncmp(curr_name, prefix.c_str(), arglen) ) if ( strncmp(curr_name, prefix.c_str(), arglen) )
continue; continue;
// If exact match, then only return that one. // If exact match, then only return that one.
if ( ! prefix.compare(curr_name) ) if ( ! prefix.compare(curr_name) ) {
{
for ( int k = 0; k < num_debug_cmds(); ++k ) for ( int k = 0; k < num_debug_cmds(); ++k )
array_of_matches[k] = nullptr; array_of_matches[k] = nullptr;
@ -196,21 +173,17 @@ int find_all_matching_cmds(const string& prefix, const char* array_of_matches[])
// Implementation of some debugger commands // Implementation of some debugger commands
// Start, end bounds of which frame numbers to print // Start, end bounds of which frame numbers to print
static int dbg_backtrace_internal(int start, int end) static int dbg_backtrace_internal(int start, int end) {
{ if ( start < 0 || end < 0 || (unsigned)start >= g_frame_stack.size() || (unsigned)end >= g_frame_stack.size() )
if ( start < 0 || end < 0 || (unsigned)start >= g_frame_stack.size() ||
(unsigned)end >= g_frame_stack.size() )
reporter->InternalError("Invalid stack frame index in DbgBacktraceInternal\n"); reporter->InternalError("Invalid stack frame index in DbgBacktraceInternal\n");
if ( start < end ) if ( start < end ) {
{
int temp = start; int temp = start;
start = end; start = end;
end = temp; end = temp;
} }
for ( int i = start; i >= end; --i ) for ( int i = start; i >= end; --i ) {
{
const Frame* f = g_frame_stack[i]; const Frame* f = g_frame_stack[i];
const Stmt* stmt = f ? f->GetNextStmt() : nullptr; const Stmt* stmt = f ? f->GetNextStmt() : nullptr;
@ -222,41 +195,35 @@ static int dbg_backtrace_internal(int start, int end)
} }
// Returns 0 for illegal arguments, or 1 on success. // Returns 0 for illegal arguments, or 1 on success.
int dbg_cmd_backtrace(DebugCmd cmd, const vector<string>& args) int dbg_cmd_backtrace(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcBacktrace); assert(cmd == dcBacktrace);
assert(g_frame_stack.size() > 0); assert(g_frame_stack.size() > 0);
unsigned int start_iter; unsigned int start_iter;
int end_iter; int end_iter;
if ( args.size() > 0 ) if ( args.size() > 0 ) {
{
int how_many; // determines how we traverse the frames int how_many; // determines how we traverse the frames
int valid_arg = sscanf(args[0].c_str(), "%i", &how_many); int valid_arg = sscanf(args[0].c_str(), "%i", &how_many);
if ( ! valid_arg ) if ( ! valid_arg ) {
{
debug_msg("Argument to backtrace '%s' invalid: must be an integer\n", args[0].c_str()); debug_msg("Argument to backtrace '%s' invalid: must be an integer\n", args[0].c_str());
return 0; return 0;
} }
if ( how_many > 0 ) if ( how_many > 0 ) { // innermost N frames
{ // innermost N frames
start_iter = g_frame_stack.size() - 1; start_iter = g_frame_stack.size() - 1;
end_iter = start_iter - how_many + 1; end_iter = start_iter - how_many + 1;
if ( end_iter < 0 ) if ( end_iter < 0 )
end_iter = 0; end_iter = 0;
} }
else else { // outermost N frames
{ // outermost N frames
start_iter = how_many - 1; start_iter = how_many - 1;
if ( start_iter + 1 > g_frame_stack.size() ) if ( start_iter + 1 > g_frame_stack.size() )
start_iter = g_frame_stack.size() - 1; start_iter = g_frame_stack.size() - 1;
end_iter = 0; end_iter = 0;
} }
} }
else else {
{
start_iter = g_frame_stack.size() - 1; start_iter = g_frame_stack.size() - 1;
end_iter = 0; end_iter = 0;
} }
@ -265,30 +232,24 @@ int dbg_cmd_backtrace(DebugCmd cmd, const vector<string>& args)
} }
// Returns 0 if invalid args, else 1. // Returns 0 if invalid args, else 1.
int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args) int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcFrame || cmd == dcUp || cmd == dcDown); assert(cmd == dcFrame || cmd == dcUp || cmd == dcDown);
if ( cmd == dcFrame ) if ( cmd == dcFrame ) {
{
int idx = 0; int idx = 0;
if ( args.size() > 0 ) if ( args.size() > 0 ) {
{ if ( args.size() > 1 ) {
if ( args.size() > 1 )
{
debug_msg("Too many arguments: expecting frame number 'n'\n"); debug_msg("Too many arguments: expecting frame number 'n'\n");
return 0; return 0;
} }
if ( ! sscanf(args[0].c_str(), "%d", &idx) ) if ( ! sscanf(args[0].c_str(), "%d", &idx) ) {
{
debug_msg("Argument to frame must be a positive integer\n"); debug_msg("Argument to frame must be a positive integer\n");
return 0; return 0;
} }
if ( idx < 0 || (unsigned int)idx >= g_frame_stack.size() ) if ( idx < 0 || (unsigned int)idx >= g_frame_stack.size() ) {
{
debug_msg("No frame %d", idx); debug_msg("No frame %d", idx);
return 0; return 0;
} }
@ -297,10 +258,8 @@ int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args)
g_debugger_state.curr_frame_idx = idx; g_debugger_state.curr_frame_idx = idx;
} }
else if ( cmd == dcDown ) else if ( cmd == dcDown ) {
{ if ( g_debugger_state.curr_frame_idx == 0 ) {
if ( g_debugger_state.curr_frame_idx == 0 )
{
debug_msg("Innermost frame already selected\n"); debug_msg("Innermost frame already selected\n");
return 0; return 0;
} }
@ -308,10 +267,8 @@ int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args)
g_debugger_state.curr_frame_idx--; g_debugger_state.curr_frame_idx--;
} }
else if ( cmd == dcUp ) else if ( cmd == dcUp ) {
{ if ( (unsigned int)(g_debugger_state.curr_frame_idx + 1) == g_frame_stack.size() ) {
if ( (unsigned int)(g_debugger_state.curr_frame_idx + 1) == g_frame_stack.size() )
{
debug_msg("Outermost frame already selected\n"); debug_msg("Outermost frame already selected\n");
return 0; return 0;
} }
@ -334,13 +291,11 @@ int dbg_cmd_frame(DebugCmd cmd, const vector<string>& args)
return dbg_backtrace_internal(user_frame_number, user_frame_number); return dbg_backtrace_internal(user_frame_number, user_frame_number);
} }
int dbg_cmd_help(DebugCmd cmd, const vector<string>& args) int dbg_cmd_help(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcHelp); assert(cmd == dcHelp);
debug_msg("Help summary: \n\n"); debug_msg("Help summary: \n\n");
for ( int i = 1; i < num_debug_cmds(); ++i ) for ( int i = 1; i < num_debug_cmds(); ++i ) {
{
const DebugCmdInfo* info = get_debug_cmd_info(DebugCmd(i)); const DebugCmdInfo* info = get_debug_cmd_info(DebugCmd(i));
debug_msg("%s -- %s\n", info->Names()[0], info->Helpstring()); debug_msg("%s -- %s\n", info->Names()[0], info->Helpstring());
} }
@ -348,16 +303,14 @@ int dbg_cmd_help(DebugCmd cmd, const vector<string>& args)
return -1; return -1;
} }
int dbg_cmd_break(DebugCmd cmd, const vector<string>& args) int dbg_cmd_break(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcBreak); assert(cmd == dcBreak);
vector<DbgBreakpoint*> bps; vector<DbgBreakpoint*> bps;
int cond_index = -1; // at which argument pos. does bp condition start? int cond_index = -1; // at which argument pos. does bp condition start?
if ( args.empty() || args[0] == "if" ) if ( args.empty() || args[0] == "if" ) { // break on next stmt
{ // break on next stmt
int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx; int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx;
Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt(); Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt();
@ -367,8 +320,7 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
DbgBreakpoint* bp = new DbgBreakpoint(); DbgBreakpoint* bp = new DbgBreakpoint();
bp->SetID(g_debugger_state.NextBPID()); bp->SetID(g_debugger_state.NextBPID());
if ( ! bp->SetLocation(stmt) ) if ( ! bp->SetLocation(stmt) ) {
{
debug_msg("Breakpoint not set.\n"); debug_msg("Breakpoint not set.\n");
delete bp; delete bp;
return 0; return 0;
@ -380,11 +332,9 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
bps.push_back(bp); bps.push_back(bp);
} }
else else {
{
vector<string> locstrings; vector<string> locstrings;
if ( string_is_regex(args[0]) ) if ( string_is_regex(args[0]) ) {
{
vector<ID*> choices; vector<ID*> choices;
choose_global_symbols_regex(args[0], choices, true); choose_global_symbols_regex(args[0], choices, true);
for ( unsigned int i = 0; i < choices.size(); ++i ) for ( unsigned int i = 0; i < choices.size(); ++i )
@ -393,16 +343,13 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
else else
locstrings.push_back(args[0]); locstrings.push_back(args[0]);
for ( unsigned int strindex = 0; strindex < locstrings.size(); ++strindex ) for ( unsigned int strindex = 0; strindex < locstrings.size(); ++strindex ) {
{
debug_msg("Setting breakpoint on %s:\n", locstrings[strindex].c_str()); debug_msg("Setting breakpoint on %s:\n", locstrings[strindex].c_str());
vector<ParseLocationRec> plrs = parse_location_string(locstrings[strindex]); vector<ParseLocationRec> plrs = parse_location_string(locstrings[strindex]);
for ( const auto& plr : plrs ) for ( const auto& plr : plrs ) {
{
DbgBreakpoint* bp = new DbgBreakpoint(); DbgBreakpoint* bp = new DbgBreakpoint();
bp->SetID(g_debugger_state.NextBPID()); bp->SetID(g_debugger_state.NextBPID());
if ( ! bp->SetLocation(plr, locstrings[strindex]) ) if ( ! bp->SetLocation(plr, locstrings[strindex]) ) {
{
debug_msg("Breakpoint not set.\n"); debug_msg("Breakpoint not set.\n");
delete bp; delete bp;
} }
@ -416,20 +363,17 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
} }
// Is there a condition specified? // Is there a condition specified?
if ( cond_index >= 0 && ! bps.empty() ) if ( cond_index >= 0 && ! bps.empty() ) {
{
// ### Implement conditions // ### Implement conditions
string cond; string cond;
for ( const auto& arg : args ) for ( const auto& arg : args ) {
{
cond += arg; cond += arg;
cond += " "; cond += " ";
} }
bps[0]->SetCondition(cond); bps[0]->SetCondition(cond);
} }
for ( auto& bp : bps ) for ( auto& bp : bps ) {
{
bp->SetTemporary(false); bp->SetTemporary(false);
g_debugger_state.breakpoints[bp->GetID()] = bp; g_debugger_state.breakpoints[bp->GetID()] = bp;
} }
@ -438,12 +382,10 @@ int dbg_cmd_break(DebugCmd cmd, const vector<string>& args)
} }
// Set a condition on an existing breakpoint. // Set a condition on an existing breakpoint.
int dbg_cmd_break_condition(DebugCmd cmd, const vector<string>& args) int dbg_cmd_break_condition(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcBreakCondition); assert(cmd == dcBreakCondition);
if ( args.size() < 2 ) if ( args.size() < 2 ) {
{
debug_msg("Arguments must specify breakpoint number and condition.\n"); debug_msg("Arguments must specify breakpoint number and condition.\n");
return 0; return 0;
} }
@ -452,8 +394,7 @@ int dbg_cmd_break_condition(DebugCmd cmd, const vector<string>& args)
DbgBreakpoint* bp = g_debugger_state.breakpoints[idx]; DbgBreakpoint* bp = g_debugger_state.breakpoints[idx];
string expr; string expr;
for ( int i = 1; i < int(args.size()); ++i ) for ( int i = 1; i < int(args.size()); ++i ) {
{
expr += args[i]; expr += args[i];
expr += " "; expr += " ";
} }
@ -463,47 +404,38 @@ int dbg_cmd_break_condition(DebugCmd cmd, const vector<string>& args)
} }
// Change the state of a breakpoint. // Change the state of a breakpoint.
int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args) int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args) {
{ assert(cmd == dcDeleteBreak || cmd == dcClearBreak || cmd == dcDisableBreak || cmd == dcEnableBreak ||
assert(cmd == dcDeleteBreak || cmd == dcClearBreak || cmd == dcDisableBreak || cmd == dcIgnoreBreak);
cmd == dcEnableBreak || cmd == dcIgnoreBreak);
if ( cmd == dcClearBreak || cmd == dcIgnoreBreak ) if ( cmd == dcClearBreak || cmd == dcIgnoreBreak ) {
{
debug_msg("'clear' and 'ignore' commands not currently supported\n"); debug_msg("'clear' and 'ignore' commands not currently supported\n");
return 0; return 0;
} }
if ( g_debugger_state.breakpoints.empty() ) if ( g_debugger_state.breakpoints.empty() ) {
{
debug_msg("No breakpoints currently set.\n"); debug_msg("No breakpoints currently set.\n");
return -1; return -1;
} }
vector<int> bps_to_change; vector<int> bps_to_change;
if ( args.empty() ) if ( args.empty() ) {
{
BPIDMapType::iterator iter; BPIDMapType::iterator iter;
for ( iter = g_debugger_state.breakpoints.begin(); for ( iter = g_debugger_state.breakpoints.begin(); iter != g_debugger_state.breakpoints.end(); ++iter )
iter != g_debugger_state.breakpoints.end(); ++iter )
bps_to_change.push_back(iter->second->GetID()); bps_to_change.push_back(iter->second->GetID());
} }
else else {
{
for ( const auto& arg : args ) for ( const auto& arg : args )
if ( int idx = atoi(arg.c_str()) ) if ( int idx = atoi(arg.c_str()) )
bps_to_change.push_back(idx); bps_to_change.push_back(idx);
} }
for ( auto bp_change : bps_to_change ) for ( auto bp_change : bps_to_change ) {
{
BPIDMapType::iterator result = g_debugger_state.breakpoints.find(bp_change); BPIDMapType::iterator result = g_debugger_state.breakpoints.find(bp_change);
if ( result != g_debugger_state.breakpoints.end() ) if ( result != g_debugger_state.breakpoints.end() ) {
{ switch ( cmd ) {
switch ( cmd )
{
case dcDisableBreak: case dcDisableBreak:
g_debugger_state.breakpoints[bp_change]->SetEnable(false); g_debugger_state.breakpoints[bp_change]->SetEnable(false);
debug_msg("Breakpoint %d disabled\n", bp_change); debug_msg("Breakpoint %d disabled\n", bp_change);
@ -520,8 +452,7 @@ int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args)
debug_msg("Breakpoint %d deleted\n", bp_change); debug_msg("Breakpoint %d deleted\n", bp_change);
break; break;
default: default: reporter->InternalError("Invalid command in DbgCmdBreakSetState\n");
reporter->InternalError("Invalid command in DbgCmdBreakSetState\n");
} }
} }
@ -533,16 +464,14 @@ int dbg_cmd_break_set_state(DebugCmd cmd, const vector<string>& args)
} }
// Evaluate an expression and print the result. // Evaluate an expression and print the result.
int dbg_cmd_print(DebugCmd cmd, const vector<string>& args) int dbg_cmd_print(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcPrint); assert(cmd == dcPrint);
// ### TODO: add support for formats // ### TODO: add support for formats
// Just concatenate all the 'args' into one expression. // Just concatenate all the 'args' into one expression.
string expr; string expr;
for ( size_t i = 0; i < args.size(); ++i ) for ( size_t i = 0; i < args.size(); ++i ) {
{
expr += args[i]; expr += args[i];
if ( i < args.size() - 1 ) if ( i < args.size() - 1 )
expr += " "; expr += " ";
@ -550,14 +479,12 @@ int dbg_cmd_print(DebugCmd cmd, const vector<string>& args)
auto val = dbg_eval_expr(expr.c_str()); auto val = dbg_eval_expr(expr.c_str());
if ( val ) if ( val ) {
{
ODesc d; ODesc d;
val->Describe(&d); val->Describe(&d);
debug_msg("%s\n", d.Description()); debug_msg("%s\n", d.Description());
} }
else else {
{
debug_msg("<expression has no value>\n"); debug_msg("<expression has no value>\n");
} }
@ -566,12 +493,10 @@ int dbg_cmd_print(DebugCmd cmd, const vector<string>& args)
// Get the debugger's state. // Get the debugger's state.
// Allowed arguments are: break (breakpoints), watch, display, source. // Allowed arguments are: break (breakpoints), watch, display, source.
int dbg_cmd_info(DebugCmd cmd, const vector<string>& args) int dbg_cmd_info(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcInfo); assert(cmd == dcInfo);
if ( args.empty() ) if ( args.empty() ) {
{
debug_msg("Syntax: info info-command\n"); debug_msg("Syntax: info info-command\n");
debug_msg("List of info-commands:\n"); debug_msg("List of info-commands:\n");
debug_msg("info breakpoints -- List of breakpoints and watches\n"); debug_msg("info breakpoints -- List of breakpoints and watches\n");
@ -579,18 +504,14 @@ int dbg_cmd_info(DebugCmd cmd, const vector<string>& args)
} }
if ( ! strncmp(args[0].c_str(), "breakpoints", args[0].size()) || if ( ! strncmp(args[0].c_str(), "breakpoints", args[0].size()) ||
! strncmp(args[0].c_str(), "watch", args[0].size()) ) ! strncmp(args[0].c_str(), "watch", args[0].size()) ) {
{
debug_msg("Num Type Disp Enb What\n"); debug_msg("Num Type Disp Enb What\n");
BPIDMapType::iterator iter; BPIDMapType::iterator iter;
for ( iter = g_debugger_state.breakpoints.begin(); for ( iter = g_debugger_state.breakpoints.begin(); iter != g_debugger_state.breakpoints.end(); ++iter ) {
iter != g_debugger_state.breakpoints.end(); ++iter )
{
DbgBreakpoint* bp = (*iter).second; DbgBreakpoint* bp = (*iter).second;
debug_msg("%-4d%-15s%-5s%-4s%s\n", bp->GetID(), "breakpoint", debug_msg("%-4d%-15s%-5s%-4s%s\n", bp->GetID(), "breakpoint", bp->IsTemporary() ? "del" : "keep",
bp->IsTemporary() ? "del" : "keep", bp->IsEnabled() ? "y" : "n", bp->IsEnabled() ? "y" : "n", bp->Description());
bp->Description());
} }
} }
@ -600,22 +521,19 @@ int dbg_cmd_info(DebugCmd cmd, const vector<string>& args)
return 1; return 1;
} }
int dbg_cmd_list(DebugCmd cmd, const vector<string>& args) int dbg_cmd_list(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcList); assert(cmd == dcList);
// The constant 4 is to match the GDB behavior. // The constant 4 is to match the GDB behavior.
const unsigned int CENTER_IDX = 4; // 5th line is the 'interesting' one const unsigned int CENTER_IDX = 4; // 5th line is the 'interesting' one
int pre_offset = 0; int pre_offset = 0;
if ( args.size() > 1 ) if ( args.size() > 1 ) {
{
debug_msg("Syntax: list [file:]line OR list function_name\n"); debug_msg("Syntax: list [file:]line OR list function_name\n");
return 0; return 0;
} }
if ( args.empty() ) if ( args.empty() ) {
{
// Special case: if we just hit a breakpoint, then show // Special case: if we just hit a breakpoint, then show
// that line without advancing first. // that line without advancing first.
if ( g_debugger_state.already_did_list ) if ( g_debugger_state.already_did_list )
@ -626,11 +544,9 @@ int dbg_cmd_list(DebugCmd cmd, const vector<string>& args)
// Why -10 ? Because that's what GDB does. // Why -10 ? Because that's what GDB does.
pre_offset = -10; pre_offset = -10;
else if ( args[0][0] == '-' || args[0][0] == '+' ) else if ( args[0][0] == '-' || args[0][0] == '+' ) {
{
int offset; int offset;
if ( ! sscanf(args[0].c_str(), "%d", &offset) ) if ( ! sscanf(args[0].c_str(), "%d", &offset) ) {
{
debug_msg("Offset must be a number\n"); debug_msg("Offset must be a number\n");
return false; return false;
} }
@ -638,12 +554,10 @@ int dbg_cmd_list(DebugCmd cmd, const vector<string>& args)
pre_offset = offset; pre_offset = offset;
} }
else else {
{
vector<ParseLocationRec> plrs = parse_location_string(args[0]); vector<ParseLocationRec> plrs = parse_location_string(args[0]);
ParseLocationRec plr = plrs[0]; ParseLocationRec plr = plrs[0];
if ( plr.type == PLR_UNKNOWN ) if ( plr.type == PLR_UNKNOWN ) {
{
debug_msg("Invalid location specifier\n"); debug_msg("Invalid location specifier\n");
return false; return false;
} }
@ -663,32 +577,27 @@ int dbg_cmd_list(DebugCmd cmd, const vector<string>& args)
if ( g_debugger_state.last_loc.first_line > last_line_in_file ) if ( g_debugger_state.last_loc.first_line > last_line_in_file )
g_debugger_state.last_loc.first_line = last_line_in_file; g_debugger_state.last_loc.first_line = last_line_in_file;
PrintLines(g_debugger_state.last_loc.filename, PrintLines(g_debugger_state.last_loc.filename, g_debugger_state.last_loc.first_line - CENTER_IDX, 10, true);
g_debugger_state.last_loc.first_line - CENTER_IDX, 10, true);
g_debugger_state.already_did_list = true; g_debugger_state.already_did_list = true;
return 1; return 1;
} }
int dbg_cmd_trace(DebugCmd cmd, const vector<string>& args) int dbg_cmd_trace(DebugCmd cmd, const vector<string>& args) {
{
assert(cmd == dcTrace); assert(cmd == dcTrace);
if ( args.empty() ) if ( args.empty() ) {
{
debug_msg("Execution tracing is %s.\n", g_trace_state.DoTrace() ? "on" : "off"); debug_msg("Execution tracing is %s.\n", g_trace_state.DoTrace() ? "on" : "off");
return 1; return 1;
} }
if ( args[0] == "on" ) if ( args[0] == "on" ) {
{
g_trace_state.TraceOn(); g_trace_state.TraceOn();
return 1; return 1;
} }
if ( args[0] == "off" ) if ( args[0] == "off" ) {
{
g_trace_state.TraceOff(); g_trace_state.TraceOff();
return 1; return 1;
} }

View file

@ -11,11 +11,9 @@
// This file is generated during the build. // This file is generated during the build.
#include "DebugCmdConstants.h" #include "DebugCmdConstants.h"
namespace zeek::detail namespace zeek::detail {
{
class DebugCmdInfo class DebugCmdInfo {
{
public: public:
DebugCmdInfo(const DebugCmdInfo& info); DebugCmdInfo(const DebugCmdInfo& info);

View file

@ -11,45 +11,36 @@
zeek::detail::DebugLogger zeek::detail::debug_logger; zeek::detail::DebugLogger zeek::detail::debug_logger;
zeek::detail::DebugLogger& debug_logger = zeek::detail::debug_logger; zeek::detail::DebugLogger& debug_logger = zeek::detail::debug_logger;
namespace zeek::detail namespace zeek::detail {
{
// Same order here as in DebugStream. // Same order here as in DebugStream.
DebugLogger::Stream DebugLogger::streams[NUM_DBGS] = { DebugLogger::Stream DebugLogger::streams[NUM_DBGS] =
{"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, {{"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, {"notifiers", 0, false},
{"notifiers", 0, false}, {"main-loop", 0, false}, {"dpd", 0, false}, {"main-loop", 0, false}, {"dpd", 0, false}, {"packet_analysis", 0, false}, {"file_analysis", 0, false},
{"packet_analysis", 0, false}, {"file_analysis", 0, false}, {"tm", 0, false}, {"tm", 0, false}, {"logging", 0, false}, {"input", 0, false}, {"threading", 0, false},
{"logging", 0, false}, {"input", 0, false}, {"threading", 0, false}, {"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, {"broker", 0, false},
{"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, {"scripts", 0, false}, {"supervisor", 0, false}, {"hashkey", 0, false}, {"spicy", 0, false}};
{"broker", 0, false}, {"scripts", 0, false}, {"supervisor", 0, false},
{"hashkey", 0, false}, {"spicy", 0, false}};
DebugLogger::DebugLogger() DebugLogger::DebugLogger() {
{
verbose = false; verbose = false;
file = nullptr; file = nullptr;
} }
DebugLogger::~DebugLogger() DebugLogger::~DebugLogger() {
{
if ( file && file != stderr ) if ( file && file != stderr )
fclose(file); fclose(file);
} }
void DebugLogger::OpenDebugLog(const char* filename) void DebugLogger::OpenDebugLog(const char* filename) {
{ if ( filename ) {
if ( filename )
{
filename = util::detail::log_file_name(filename); filename = util::detail::log_file_name(filename);
file = fopen(filename, "w"); file = fopen(filename, "w");
if ( ! file ) if ( ! file ) {
{
// The reporter may not be initialized here yet. // The reporter may not be initialized here yet.
if ( reporter ) if ( reporter )
reporter->FatalError("can't open '%s' for debugging output", filename); reporter->FatalError("can't open '%s' for debugging output", filename);
else else {
{
fprintf(stderr, "can't open '%s' for debugging output\n", filename); fprintf(stderr, "can't open '%s' for debugging output\n", filename);
exit(1); exit(1);
} }
@ -61,8 +52,7 @@ void DebugLogger::OpenDebugLog(const char* filename)
file = stderr; file = stderr;
} }
void DebugLogger::ShowStreamsHelp() void DebugLogger::ShowStreamsHelp() {
{
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "Enable debug output into debug.log with -B <streams>.\n"); fprintf(stderr, "Enable debug output into debug.log with -B <streams>.\n");
fprintf(stderr, "<streams> is a comma-separated list of streams to enable.\n"); fprintf(stderr, "<streams> is a comma-separated list of streams to enable.\n");
@ -73,7 +63,8 @@ void DebugLogger::ShowStreamsHelp()
fprintf(stderr, " %s\n", streams[i].prefix); fprintf(stderr, " %s\n", streams[i].prefix);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, " plugin-<plugin-name> (replace '::' in name with '-'; e.g., '-B " fprintf(stderr,
" plugin-<plugin-name> (replace '::' in name with '-'; e.g., '-B "
"plugin-Zeek-Netmap')\n"); "plugin-Zeek-Netmap')\n");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
fprintf(stderr, "Pseudo streams\n"); fprintf(stderr, "Pseudo streams\n");
@ -82,18 +73,14 @@ void DebugLogger::ShowStreamsHelp()
fprintf(stderr, "\n"); fprintf(stderr, "\n");
} }
void DebugLogger::EnableStreams(const char* s) void DebugLogger::EnableStreams(const char* s) {
{
char* brkt; char* brkt;
char* tmp = util::copy_string(s); char* tmp = util::copy_string(s);
char* tok = strtok(tmp, ","); char* tok = strtok(tmp, ",");
while ( tok ) while ( tok ) {
{ if ( strcasecmp("all", tok) == 0 ) {
if ( strcasecmp("all", tok) == 0 ) for ( int i = 0; i < NUM_DBGS; ++i ) {
{
for ( int i = 0; i < NUM_DBGS; ++i )
{
streams[i].enabled = true; streams[i].enabled = true;
enabled_streams.insert(streams[i].prefix); enabled_streams.insert(streams[i].prefix);
} }
@ -102,20 +89,17 @@ void DebugLogger::EnableStreams(const char* s)
goto next; goto next;
} }
if ( strcasecmp("verbose", tok) == 0 ) if ( strcasecmp("verbose", tok) == 0 ) {
{
verbose = true; verbose = true;
goto next; goto next;
} }
if ( strcasecmp("help", tok) == 0 ) if ( strcasecmp("help", tok) == 0 ) {
{
ShowStreamsHelp(); ShowStreamsHelp();
exit(0); exit(0);
} }
if ( util::starts_with(tok, "plugin-") ) if ( util::starts_with(tok, "plugin-") ) {
{
// Cannot verify this at this time, plugins may not // Cannot verify this at this time, plugins may not
// have been loaded. // have been loaded.
enabled_streams.insert(tok); enabled_streams.insert(tok);
@ -124,10 +108,8 @@ void DebugLogger::EnableStreams(const char* s)
int i; int i;
for ( i = 0; i < NUM_DBGS; ++i ) for ( i = 0; i < NUM_DBGS; ++i ) {
{ if ( strcasecmp(streams[i].prefix, tok) == 0 ) {
if ( strcasecmp(streams[i].prefix, tok) == 0 )
{
streams[i].enabled = true; streams[i].enabled = true;
enabled_streams.insert(tok); enabled_streams.insert(tok);
goto next; goto next;
@ -143,21 +125,18 @@ void DebugLogger::EnableStreams(const char* s)
delete[] tmp; delete[] tmp;
} }
bool DebugLogger::CheckStreams(const std::set<std::string>& plugin_names) bool DebugLogger::CheckStreams(const std::set<std::string>& plugin_names) {
{
bool ok = true; bool ok = true;
std::set<std::string> available_plugin_streams; std::set<std::string> available_plugin_streams;
for ( const auto& p : plugin_names ) for ( const auto& p : plugin_names )
available_plugin_streams.insert(PluginStreamName(p)); available_plugin_streams.insert(PluginStreamName(p));
for ( const auto& stream : enabled_streams ) for ( const auto& stream : enabled_streams ) {
{
if ( ! util::starts_with(stream, "plugin-") ) if ( ! util::starts_with(stream, "plugin-") )
continue; continue;
if ( available_plugin_streams.count(stream) == 0 ) if ( available_plugin_streams.count(stream) == 0 ) {
{
reporter->Error("No plugin debug stream '%s' found", stream.c_str()); reporter->Error("No plugin debug stream '%s' found", stream.c_str());
ok = false; ok = false;
} }
@ -166,15 +145,13 @@ bool DebugLogger::CheckStreams(const std::set<std::string>& plugin_names)
return ok; return ok;
} }
void DebugLogger::Log(DebugStream stream, const char* fmt, ...) void DebugLogger::Log(DebugStream stream, const char* fmt, ...) {
{
Stream* g = &streams[int(stream)]; Stream* g = &streams[int(stream)];
if ( ! g->enabled ) if ( ! g->enabled )
return; return;
fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), g->prefix);
g->prefix);
for ( int i = g->indent; i > 0; --i ) for ( int i = g->indent; i > 0; --i )
fputs(" ", file); fputs(" ", file);
@ -188,8 +165,7 @@ void DebugLogger::Log(DebugStream stream, const char* fmt, ...)
fflush(file); fflush(file);
} }
void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) {
{
std::string tok = PluginStreamName(plugin.Name()); std::string tok = PluginStreamName(plugin.Name());
if ( enabled_streams.find(tok) == enabled_streams.end() ) if ( enabled_streams.find(tok) == enabled_streams.end() )

View file

@ -17,27 +17,23 @@
if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \ if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__)
#define DBG_LOG_VERBOSE(stream, ...) \ #define DBG_LOG_VERBOSE(stream, ...) \
if ( ::zeek::detail::debug_logger.IsVerbose() && \ if ( ::zeek::detail::debug_logger.IsVerbose() && ::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.IsEnabled(stream) ) \
::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__)
#define DBG_PUSH(stream) ::zeek::detail::debug_logger.PushIndent(stream) #define DBG_PUSH(stream) ::zeek::detail::debug_logger.PushIndent(stream)
#define DBG_POP(stream) ::zeek::detail::debug_logger.PopIndent(stream) #define DBG_POP(stream) ::zeek::detail::debug_logger.PopIndent(stream)
#define PLUGIN_DBG_LOG(plugin, ...) ::zeek::detail::debug_logger.Log(plugin, __VA_ARGS__) #define PLUGIN_DBG_LOG(plugin, ...) ::zeek::detail::debug_logger.Log(plugin, __VA_ARGS__)
namespace zeek namespace zeek {
{
namespace plugin namespace plugin {
{
class Plugin; class Plugin;
} }
// To add a new debugging stream, add a constant here as well as // To add a new debugging stream, add a constant here as well as
// an entry to DebugLogger::streams in DebugLogger.cc. // an entry to DebugLogger::streams in DebugLogger.cc.
enum DebugStream enum DebugStream {
{
DBG_SERIAL, // Serialization DBG_SERIAL, // Serialization
DBG_RULES, // Signature matching DBG_RULES, // Signature matching
DBG_STRING, // String code DBG_STRING, // String code
@ -62,11 +58,9 @@ enum DebugStream
NUM_DBGS // Has to be last NUM_DBGS // Has to be last
}; };
namespace detail namespace detail {
{
class DebugLogger class DebugLogger {
{
public: public:
// Output goes to stderr per default. // Output goes to stderr per default.
DebugLogger(); DebugLogger();
@ -75,8 +69,7 @@ public:
void OpenDebugLog(const char* filename = 0); void OpenDebugLog(const char* filename = 0);
void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4))); void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4)));
void Log(const plugin::Plugin& plugin, const char* fmt, ...) void Log(const plugin::Plugin& plugin, const char* fmt, ...) __attribute__((format(printf, 3, 4)));
__attribute__((format(printf, 3, 4)));
void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; } void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; }
void PopIndent(DebugStream stream) { --streams[int(stream)].indent; } void PopIndent(DebugStream stream) { --streams[int(stream)].indent; }
@ -101,8 +94,7 @@ private:
FILE* file; FILE* file;
bool verbose; bool verbose;
struct Stream struct Stream {
{
const char* prefix; const char* prefix;
int indent; int indent;
bool enabled; bool enabled;
@ -112,8 +104,7 @@ private:
static Stream streams[NUM_DBGS]; static Stream streams[NUM_DBGS];
const std::string PluginStreamName(const std::string& plugin_name) const std::string PluginStreamName(const std::string& plugin_name) {
{
return "plugin-" + util::strreplace(plugin_name, "::", "-"); return "plugin-" + util::strreplace(plugin_name, "::", "-");
} }
}; };

View file

@ -17,24 +17,20 @@
#define DEFAULT_SIZE 128 #define DEFAULT_SIZE 128
#define SLOP 10 #define SLOP 10
namespace zeek namespace zeek {
{
ODesc::ODesc(DescType t, File* arg_f) ODesc::ODesc(DescType t, File* arg_f) {
{
type = t; type = t;
style = STANDARD_STYLE; style = STANDARD_STYLE;
f = arg_f; f = arg_f;
if ( f == nullptr ) if ( f == nullptr ) {
{
size = DEFAULT_SIZE; size = DEFAULT_SIZE;
base = util::safe_malloc(size); base = util::safe_malloc(size);
((char*)base)[0] = '\0'; ((char*)base)[0] = '\0';
offset = 0; offset = 0;
} }
else else {
{
offset = size = 0; offset = size = 0;
base = nullptr; base = nullptr;
} }
@ -50,10 +46,8 @@ ODesc::ODesc(DescType t, File* arg_f)
utf8 = false; utf8 = false;
} }
ODesc::~ODesc() ODesc::~ODesc() {
{ if ( f ) {
if ( f )
{
if ( do_flush ) if ( do_flush )
f->Flush(); f->Flush();
} }
@ -61,38 +55,28 @@ ODesc::~ODesc()
free(base); free(base);
} }
void ODesc::EnableEscaping() void ODesc::EnableEscaping() { escape = true; }
{
escape = true;
}
void ODesc::EnableUTF8() void ODesc::EnableUTF8() { utf8 = true; }
{
utf8 = true;
}
void ODesc::PushIndent() void ODesc::PushIndent() {
{
++indent_level; ++indent_level;
NL(); NL();
} }
void ODesc::PopIndent() void ODesc::PopIndent() {
{
if ( --indent_level < 0 ) if ( --indent_level < 0 )
reporter->InternalError("ODesc::PopIndent underflow"); reporter->InternalError("ODesc::PopIndent underflow");
NL(); NL();
} }
void ODesc::PopIndentNoNL() void ODesc::PopIndentNoNL() {
{
if ( --indent_level < 0 ) if ( --indent_level < 0 )
reporter->InternalError("ODesc::PopIndent underflow"); reporter->InternalError("ODesc::PopIndent underflow");
} }
void ODesc::Add(const char* s, int do_indent) void ODesc::Add(const char* s, int do_indent) {
{
unsigned int n = strlen(s); unsigned int n = strlen(s);
if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' ) if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' )
@ -104,60 +88,50 @@ void ODesc::Add(const char* s, int do_indent)
AddBytes(s, n); AddBytes(s, n);
} }
void ODesc::Add(int i) void ODesc::Add(int i) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&i, sizeof(i)); AddBytes(&i, sizeof(i));
else else {
{
char tmp[256]; char tmp[256];
modp_litoa10(i, tmp); modp_litoa10(i, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(uint32_t u) void ODesc::Add(uint32_t u) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&u, sizeof(u)); AddBytes(&u, sizeof(u));
else else {
{
char tmp[256]; char tmp[256];
modp_ulitoa10(u, tmp); modp_ulitoa10(u, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(int64_t i) void ODesc::Add(int64_t i) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&i, sizeof(i)); AddBytes(&i, sizeof(i));
else else {
{
char tmp[256]; char tmp[256];
modp_litoa10(i, tmp); modp_litoa10(i, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(uint64_t u) void ODesc::Add(uint64_t u) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&u, sizeof(u)); AddBytes(&u, sizeof(u));
else else {
{
char tmp[256]; char tmp[256];
modp_ulitoa10(u, tmp); modp_ulitoa10(u, tmp);
Add(tmp); Add(tmp);
} }
} }
void ODesc::Add(double d, bool no_exp) void ODesc::Add(double d, bool no_exp) {
{
if ( IsBinary() ) if ( IsBinary() )
AddBytes(&d, sizeof(d)); AddBytes(&d, sizeof(d));
else else {
{
// Buffer needs enough chars to store max. possible "double" value // Buffer needs enough chars to store max. possible "double" value
// of 1.79e308 without using scientific notation. // of 1.79e308 without using scientific notation.
char tmp[350]; char tmp[350];
@ -169,8 +143,7 @@ void ODesc::Add(double d, bool no_exp)
Add(tmp); Add(tmp);
auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool {
{
auto v = a - b; auto v = a - b;
return v < 0 ? -v < tolerance : v < tolerance; return v < 0 ? -v < tolerance : v < tolerance;
}; };
@ -181,18 +154,11 @@ void ODesc::Add(double d, bool no_exp)
} }
} }
void ODesc::Add(const IPAddr& addr) void ODesc::Add(const IPAddr& addr) { Add(addr.AsString()); }
{
Add(addr.AsString());
}
void ODesc::Add(const IPPrefix& prefix) void ODesc::Add(const IPPrefix& prefix) { Add(prefix.AsString()); }
{
Add(prefix.AsString());
}
void ODesc::AddCS(const char* s) void ODesc::AddCS(const char* s) {
{
int n = strlen(s); int n = strlen(s);
Add(n); Add(n);
if ( ! IsBinary() ) if ( ! IsBinary() )
@ -200,21 +166,17 @@ void ODesc::AddCS(const char* s)
Add(s); Add(s);
} }
void ODesc::AddBytes(const String* s) void ODesc::AddBytes(const String* s) {
{ if ( IsReadable() ) {
if ( IsReadable() )
{
if ( Style() == RAW_STYLE ) if ( Style() == RAW_STYLE )
AddBytes(reinterpret_cast<const char*>(s->Bytes()), s->Len()); AddBytes(reinterpret_cast<const char*>(s->Bytes()), s->Len());
else else {
{
const char* str = s->Render(String::EXPANDED_STRING); const char* str = s->Render(String::EXPANDED_STRING);
Add(str); Add(str);
delete[] str; delete[] str;
} }
} }
else else {
{
Add(s->Len()); Add(s->Len());
if ( ! IsBinary() ) if ( ! IsBinary() )
Add(" "); Add(" ");
@ -222,23 +184,19 @@ void ODesc::AddBytes(const String* s)
} }
} }
void ODesc::Indent() void ODesc::Indent() {
{ if ( indent_with_spaces > 0 ) {
if ( indent_with_spaces > 0 )
{
for ( int i = 0; i < indent_level; ++i ) for ( int i = 0; i < indent_level; ++i )
for ( int j = 0; j < indent_with_spaces; ++j ) for ( int j = 0; j < indent_with_spaces; ++j )
Add(" ", 0); Add(" ", 0);
} }
else else {
{
for ( int i = 0; i < indent_level; ++i ) for ( int i = 0; i < indent_level; ++i )
Add("\t", 0); Add("\t", 0);
} }
} }
static bool starts_with(const char* str1, const char* str2, size_t len) static bool starts_with(const char* str1, const char* str2, size_t len) {
{
for ( size_t i = 0; i < len; ++i ) for ( size_t i = 0; i < len; ++i )
if ( str1[i] != str2[i] ) if ( str1[i] != str2[i] )
return false; return false;
@ -246,13 +204,11 @@ static bool starts_with(const char* str1, const char* str2, size_t len)
return true; return true;
} }
size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) {
{
if ( escape_sequences.empty() ) if ( escape_sequences.empty() )
return 0; return 0;
for ( const auto& esc_str : escape_sequences ) for ( const auto& esc_str : escape_sequences ) {
{
size_t esc_len = esc_str.length(); size_t esc_len = esc_str.length();
if ( start + esc_len > end ) if ( start + esc_len > end )
@ -265,13 +221,11 @@ size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end)
return 0; return 0;
} }
std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n) std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n) {
{
if ( IsBinary() ) if ( IsBinary() )
return {nullptr, 0}; return {nullptr, 0};
for ( size_t i = 0; i < n; ++i ) for ( size_t i = 0; i < n; ++i ) {
{
auto printable = isprint(bytes[i]); auto printable = isprint(bytes[i]);
if ( ! printable && ! utf8 ) if ( ! printable && ! utf8 )
@ -289,10 +243,8 @@ std::pair<const char*, size_t> ODesc::FirstEscapeLoc(const char* bytes, size_t n
return {nullptr, 0}; return {nullptr, 0};
} }
void ODesc::AddBytes(const void* bytes, unsigned int n) void ODesc::AddBytes(const void* bytes, unsigned int n) {
{ if ( ! escape ) {
if ( ! escape )
{
AddBytesRaw(bytes, n); AddBytesRaw(bytes, n);
return; return;
} }
@ -300,14 +252,11 @@ void ODesc::AddBytes(const void* bytes, unsigned int n)
const char* s = (const char*)bytes; const char* s = (const char*)bytes;
const char* e = (const char*)bytes + n; const char* e = (const char*)bytes + n;
while ( s < e ) while ( s < e ) {
{
auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s); auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s);
if ( esc_start != nullptr ) if ( esc_start != nullptr ) {
{ if ( utf8 ) {
if ( utf8 )
{
std::string result = util::json_escape_utf8(s, esc_start - s, false); std::string result = util::json_escape_utf8(s, esc_start - s, false);
AddBytesRaw(result.c_str(), result.size()); AddBytesRaw(result.c_str(), result.size());
} }
@ -317,10 +266,8 @@ void ODesc::AddBytes(const void* bytes, unsigned int n)
util::get_escaped_string(this, esc_start, esc_len, true); util::get_escaped_string(this, esc_start, esc_len, true);
s = esc_start + esc_len; s = esc_start + esc_len;
} }
else else {
{ if ( utf8 ) {
if ( utf8 )
{
std::string result = util::json_escape_utf8(s, e - s, false); std::string result = util::json_escape_utf8(s, e - s, false);
AddBytesRaw(result.c_str(), result.size()); AddBytesRaw(result.c_str(), result.size());
} }
@ -332,17 +279,14 @@ void ODesc::AddBytes(const void* bytes, unsigned int n)
} }
} }
void ODesc::AddBytesRaw(const void* bytes, unsigned int n) void ODesc::AddBytesRaw(const void* bytes, unsigned int n) {
{
if ( n == 0 ) if ( n == 0 )
return; return;
if ( f ) if ( f ) {
{
static bool write_failed = false; static bool write_failed = false;
if ( ! f->Write((const char*)bytes, n) ) if ( ! f->Write((const char*)bytes, n) ) {
{
if ( ! write_failed ) if ( ! write_failed )
// Most likely it's a "disk full" so report // Most likely it's a "disk full" so report
// subsequent failures only once. // subsequent failures only once.
@ -355,8 +299,7 @@ void ODesc::AddBytesRaw(const void* bytes, unsigned int n)
write_failed = false; write_failed = false;
} }
else else {
{
Grow(n); Grow(n);
// The following casting contortions are necessary because // The following casting contortions are necessary because
@ -369,11 +312,9 @@ void ODesc::AddBytesRaw(const void* bytes, unsigned int n)
} }
} }
void ODesc::Grow(unsigned int n) void ODesc::Grow(unsigned int n) {
{
bool size_changed = false; bool size_changed = false;
while ( offset + n + SLOP >= size ) while ( offset + n + SLOP >= size ) {
{
size *= 2; size *= 2;
size_changed = true; size_changed = true;
} }
@ -382,13 +323,11 @@ void ODesc::Grow(unsigned int n)
base = util::safe_realloc(base, size); base = util::safe_realloc(base, size);
} }
void ODesc::Clear() void ODesc::Clear() {
{
offset = 0; offset = 0;
// If we've allocated an exceedingly large amount of space, free it. // If we've allocated an exceedingly large amount of space, free it.
if ( size > 10 * 1024 * 1024 ) if ( size > 10 * 1024 * 1024 ) {
{
free(base); free(base);
size = DEFAULT_SIZE; size = DEFAULT_SIZE;
base = util::safe_malloc(size); base = util::safe_malloc(size);
@ -396,20 +335,17 @@ void ODesc::Clear()
} }
} }
bool ODesc::PushType(const Type* type) bool ODesc::PushType(const Type* type) {
{
auto res = encountered_types.insert(type); auto res = encountered_types.insert(type);
return std::get<1>(res); return std::get<1>(res);
} }
bool ODesc::PopType(const Type* type) bool ODesc::PopType(const Type* type) {
{
size_t res = encountered_types.erase(type); size_t res = encountered_types.erase(type);
return (res == 1); return (res == 1);
} }
bool ODesc::FindType(const Type* type) bool ODesc::FindType(const Type* type) {
{
auto res = encountered_types.find(type); auto res = encountered_types.find(type);
if ( res != encountered_types.end() ) if ( res != encountered_types.end() )
@ -418,8 +354,7 @@ bool ODesc::FindType(const Type* type)
return false; return false;
} }
std::string obj_desc(const Obj* o) std::string obj_desc(const Obj* o) {
{
static ODesc d; static ODesc d;
d.Clear(); d.Clear();
@ -430,8 +365,7 @@ std::string obj_desc(const Obj* o)
return d.Description(); return d.Description();
} }
std::string obj_desc_short(const Obj* o) std::string obj_desc_short(const Obj* o) {
{
static ODesc d; static ODesc d;
d.SetShort(true); d.SetShort(true);

View file

@ -10,28 +10,24 @@
#include "zeek/ZeekString.h" // for byte_vec #include "zeek/ZeekString.h" // for byte_vec
#include "zeek/util.h" // for zeek_int_t #include "zeek/util.h" // for zeek_int_t
namespace zeek namespace zeek {
{
class IPAddr; class IPAddr;
class IPPrefix; class IPPrefix;
class File; class File;
class Type; class Type;
enum DescType enum DescType {
{
DESC_READABLE, DESC_READABLE,
DESC_BINARY, DESC_BINARY,
}; };
enum DescStyle enum DescStyle {
{
STANDARD_STYLE, STANDARD_STYLE,
RAW_STYLE, RAW_STYLE,
}; };
class ODesc class ODesc {
{
public: public:
explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr); explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr);
@ -69,10 +65,7 @@ public:
void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); } void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); }
void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); } void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); }
void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); } void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); }
void RemoveEscapeSequence(const char* s, size_t n) void RemoveEscapeSequence(const char* s, size_t n) { escape_sequences.erase(std::string(s, n)); }
{
escape_sequences.erase(std::string(s, n));
}
void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); } void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); }
void PushIndent(); void PushIndent();
@ -100,40 +93,33 @@ public:
void AddBytes(const String* s); void AddBytes(const String* s);
void Add(const char* s1, const char* s2) void Add(const char* s1, const char* s2) {
{
Add(s1); Add(s1);
Add(s2); Add(s2);
} }
void AddSP(const char* s1, const char* s2) void AddSP(const char* s1, const char* s2) {
{
Add(s1); Add(s1);
AddSP(s2); AddSP(s2);
} }
void AddSP(const char* s) void AddSP(const char* s) {
{
Add(s); Add(s);
SP(); SP();
} }
void AddCount(zeek_int_t n) void AddCount(zeek_int_t n) {
{ if ( ! IsReadable() ) {
if ( ! IsReadable() )
{
Add(n); Add(n);
SP(); SP();
} }
} }
void SP() void SP() {
{
if ( ! IsBinary() ) if ( ! IsBinary() )
Add(" ", 0); Add(" ", 0);
} }
void NL() void NL() {
{
if ( ! IsBinary() && ! is_short ) if ( ! IsBinary() && ! is_short )
Add("\n", 0); Add("\n", 0);
} }
@ -146,8 +132,7 @@ public:
const char* Description() const { return (const char*)base; } const char* Description() const { return (const char*)base; }
const u_char* Bytes() const { return (const u_char*)base; } const u_char* Bytes() const { return (const u_char*)base; }
byte_vec TakeBytes() byte_vec TakeBytes() {
{
const void* t = base; const void* t = base;
base = nullptr; base = nullptr;
size = 0; size = 0;

View file

@ -5,15 +5,13 @@
#include "zeek/3rdparty/doctest.h" #include "zeek/3rdparty/doctest.h"
#include "zeek/Hash.h" #include "zeek/Hash.h"
namespace zeek namespace zeek {
{
// namespace detail // namespace detail
TEST_SUITE_BEGIN("Dict"); TEST_SUITE_BEGIN("Dict");
TEST_CASE("dict construction") TEST_CASE("dict construction") {
{
PDict<int> dict; PDict<int> dict;
CHECK(! dict.IsOrdered()); CHECK(! dict.IsOrdered());
CHECK(dict.Length() == 0); CHECK(dict.Length() == 0);
@ -23,8 +21,7 @@ TEST_CASE("dict construction")
CHECK(dict2.Length() == 0); CHECK(dict2.Length() == 0);
} }
TEST_CASE("dict operation") TEST_CASE("dict operation") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 10; uint32_t val = 10;
@ -69,8 +66,7 @@ TEST_CASE("dict operation")
delete key2; delete key2;
} }
TEST_CASE("dict nthentry") TEST_CASE("dict nthentry") {
{
PDict<uint32_t> unordered(UNORDERED); PDict<uint32_t> unordered(UNORDERED);
PDict<uint32_t> ordered(ORDERED); PDict<uint32_t> ordered(ORDERED);
@ -105,8 +101,7 @@ TEST_CASE("dict nthentry")
delete ukey2; delete ukey2;
} }
TEST_CASE("dict iteration") TEST_CASE("dict iteration") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 15; uint32_t val = 15;
@ -122,13 +117,11 @@ TEST_CASE("dict iteration")
int count = 0; int count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint64_t k = *(uint32_t*)entry.GetKey(); uint64_t k = *(uint32_t*)entry.GetKey();
switch ( count ) switch ( count ) {
{
case 0: case 0:
CHECK(k == key_val2); CHECK(k == key_val2);
CHECK(*v == val2); CHECK(*v == val2);
@ -137,8 +130,7 @@ TEST_CASE("dict iteration")
CHECK(k == key_val); CHECK(k == key_val);
CHECK(*v == val); CHECK(*v == val);
break; break;
default: default: break;
break;
} }
count++; count++;
@ -155,8 +147,7 @@ TEST_CASE("dict iteration")
delete key2; delete key2;
} }
TEST_CASE("dict robust iteration") TEST_CASE("dict robust iteration") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 15; uint32_t val = 15;
@ -178,13 +169,11 @@ TEST_CASE("dict robust iteration")
int count = 0; int count = 0;
auto it = dict.begin_robust(); auto it = dict.begin_robust();
for ( ; it != dict.end_robust(); ++it ) for ( ; it != dict.end_robust(); ++it ) {
{
auto* v = it->value; auto* v = it->value;
uint64_t k = *(uint32_t*)it->GetKey(); uint64_t k = *(uint32_t*)it->GetKey();
switch ( count ) switch ( count ) {
{
case 0: case 0:
CHECK(k == key_val2); CHECK(k == key_val2);
CHECK(*v == val2); CHECK(*v == val2);
@ -213,13 +202,11 @@ TEST_CASE("dict robust iteration")
int count = 0; int count = 0;
auto it = dict.begin_robust(); auto it = dict.begin_robust();
for ( ; it != dict.end_robust(); ++it ) for ( ; it != dict.end_robust(); ++it ) {
{
auto* v = it->value; auto* v = it->value;
uint64_t k = *(uint32_t*)it->GetKey(); uint64_t k = *(uint32_t*)it->GetKey();
switch ( count ) switch ( count ) {
{
case 0: case 0:
CHECK(k == key_val2); CHECK(k == key_val2);
CHECK(*v == val2); CHECK(*v == val2);
@ -246,8 +233,7 @@ TEST_CASE("dict robust iteration")
delete key3; delete key3;
} }
TEST_CASE("dict ordered iteration") TEST_CASE("dict ordered iteration") {
{
PDict<uint32_t> dict(DictOrder::ORDERED); PDict<uint32_t> dict(DictOrder::ORDERED);
// These key values are specifically contrived to be inserted // These key values are specifically contrived to be inserted
@ -276,8 +262,7 @@ TEST_CASE("dict ordered iteration")
int count = 0; int count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey(); uint32_t k = *(uint32_t*)entry.GetKey();
@ -296,8 +281,7 @@ TEST_CASE("dict ordered iteration")
dict.Insert(key4.get(), &val4); dict.Insert(key4.get(), &val4);
count = 0; count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey(); uint32_t k = *(uint32_t*)entry.GetKey();
@ -318,8 +302,7 @@ TEST_CASE("dict ordered iteration")
dict.Remove(key2.get()); dict.Remove(key2.get());
count = 0; count = 0;
for ( const auto& entry : dict ) for ( const auto& entry : dict ) {
{
auto* v = static_cast<uint32_t*>(entry.value); auto* v = static_cast<uint32_t*>(entry.value);
uint32_t k = *(uint32_t*)entry.GetKey(); uint32_t k = *(uint32_t*)entry.GetKey();
@ -336,16 +319,14 @@ TEST_CASE("dict ordered iteration")
} }
} }
class DictTestDummy class DictTestDummy {
{
public: public:
DictTestDummy(int v) : v(v) {} DictTestDummy(int v) : v(v) {}
~DictTestDummy() = default; ~DictTestDummy() = default;
int v = 0; int v = 0;
}; };
TEST_CASE("dict robust iteration replacement") TEST_CASE("dict robust iteration replacement") {
{
PDict<DictTestDummy> dict; PDict<DictTestDummy> dict;
DictTestDummy* val1 = new DictTestDummy(15); DictTestDummy* val1 = new DictTestDummy(15);
@ -369,7 +350,8 @@ TEST_CASE("dict robust iteration replacement")
// Iterate past the first couple of elements so we're not done, but the // Iterate past the first couple of elements so we're not done, but the
// iterator is still pointing at a valid element. // iterator is still pointing at a valid element.
for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) { } for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) {
}
// Store off the value at this iterator index // Store off the value at this iterator index
auto* v = it->value; auto* v = it->value;
@ -383,8 +365,7 @@ TEST_CASE("dict robust iteration replacement")
delete val2; delete val2;
// This shouldn't crash with AddressSanitizer // This shouldn't crash with AddressSanitizer
for ( ; it != dict.end_robust(); ++it ) for ( ; it != dict.end_robust(); ++it ) {
{
uint64_t k = *(uint32_t*)it->GetKey(); uint64_t k = *(uint32_t*)it->GetKey();
auto* v = it->value; auto* v = it->value;
CHECK(v->v == 50); CHECK(v->v == 50);
@ -399,8 +380,7 @@ TEST_CASE("dict robust iteration replacement")
delete val4; delete val4;
} }
TEST_CASE("dict iterator invalidation") TEST_CASE("dict iterator invalidation") {
{
PDict<uint32_t> dict; PDict<uint32_t> dict;
uint32_t val = 15; uint32_t val = 15;
@ -454,9 +434,6 @@ TEST_CASE("dict iterator invalidation")
} }
// private // private
void generic_delete_func(void* v) void generic_delete_func(void* v) { free(v); }
{
free(v);
}
} // namespace zeek } // namespace zeek

File diff suppressed because it is too large Load diff

View file

@ -14,11 +14,9 @@
#include "zeek/Var.h" #include "zeek/Var.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek::detail namespace zeek::detail {
{
Discarder::Discarder() Discarder::Discarder() {
{
check_ip = id::find_func("discarder_check_ip"); check_ip = id::find_func("discarder_check_ip");
check_tcp = id::find_func("discarder_check_tcp"); check_tcp = id::find_func("discarder_check_tcp");
check_udp = id::find_func("discarder_check_udp"); check_udp = id::find_func("discarder_check_udp");
@ -27,26 +25,19 @@ Discarder::Discarder()
discarder_maxlen = static_cast<int>(id::find_val("discarder_maxlen")->AsCount()); discarder_maxlen = static_cast<int>(id::find_val("discarder_maxlen")->AsCount());
} }
bool Discarder::IsActive() bool Discarder::IsActive() { return check_ip || check_tcp || check_udp || check_icmp; }
{
return check_ip || check_tcp || check_udp || check_icmp;
}
bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) {
{
bool discard_packet = false; bool discard_packet = false;
if ( check_ip ) if ( check_ip ) {
{
zeek::Args args{ip->ToPktHdrVal()}; zeek::Args args{ip->ToPktHdrVal()};
try try {
{
discard_packet = check_ip->Invoke(&args)->AsBool(); discard_packet = check_ip->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
@ -70,8 +61,7 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
bool is_tcp = (proto == IPPROTO_TCP); bool is_tcp = (proto == IPPROTO_TCP);
bool is_udp = (proto == IPPROTO_UDP); bool is_udp = (proto == IPPROTO_UDP);
int min_hdr_len = is_tcp ? sizeof(struct tcphdr) int min_hdr_len = is_tcp ? sizeof(struct tcphdr) : (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp));
: (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp));
if ( len < min_hdr_len || caplen < min_hdr_len ) if ( len < min_hdr_len || caplen < min_hdr_len )
// we don't have a complete protocol header // we don't have a complete protocol header
@ -81,10 +71,8 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
// this gets advanced past the transport header. // this gets advanced past the transport header.
const u_char* data = ip->Payload(); const u_char* data = ip->Payload();
if ( is_tcp ) if ( is_tcp ) {
{ if ( check_tcp ) {
if ( check_tcp )
{
const struct tcphdr* tp = (const struct tcphdr*)data; const struct tcphdr* tp = (const struct tcphdr*)data;
int th_len = tp->th_off * 4; int th_len = tp->th_off * 4;
@ -93,22 +81,18 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
{AdoptRef{}, BuildData(data, th_len, len, caplen)}, {AdoptRef{}, BuildData(data, th_len, len, caplen)},
}; };
try try {
{
discard_packet = check_tcp->Invoke(&args)->AsBool(); discard_packet = check_tcp->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
} }
} }
else if ( is_udp ) else if ( is_udp ) {
{ if ( check_udp ) {
if ( check_udp )
{
const struct udphdr* up = (const struct udphdr*)data; const struct udphdr* up = (const struct udphdr*)data;
int uh_len = sizeof(struct udphdr); int uh_len = sizeof(struct udphdr);
@ -117,33 +101,27 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
{AdoptRef{}, BuildData(data, uh_len, len, caplen)}, {AdoptRef{}, BuildData(data, uh_len, len, caplen)},
}; };
try try {
{
discard_packet = check_udp->Invoke(&args)->AsBool(); discard_packet = check_udp->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
} }
} }
else else {
{ if ( check_icmp ) {
if ( check_icmp )
{
const struct icmp* ih = (const struct icmp*)data; const struct icmp* ih = (const struct icmp*)data;
zeek::Args args{ip->ToPktHdrVal()}; zeek::Args args{ip->ToPktHdrVal()};
try try {
{
discard_packet = check_icmp->Invoke(&args)->AsBool(); discard_packet = check_icmp->Invoke(&args)->AsBool();
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
discard_packet = false; discard_packet = false;
} }
} }
@ -152,8 +130,7 @@ bool Discarder::NextPacket(const std::shared_ptr<IP_Hdr>& ip, int len, int caple
return discard_packet; return discard_packet;
} }
Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) {
{
len -= hdrlen; len -= hdrlen;
caplen -= hdrlen; caplen -= hdrlen;
data += hdrlen; data += hdrlen;

View file

@ -7,19 +7,16 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
class Val; class Val;
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
namespace detail namespace detail {
{
class Discarder final class Discarder final {
{
public: public:
Discarder(); Discarder();
~Discarder() = default; ~Discarder() = default;

View file

@ -7,11 +7,9 @@
#include "zeek/CCL.h" #include "zeek/CCL.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
EquivClass::EquivClass(int arg_size) EquivClass::EquivClass(int arg_size) {
{
size = arg_size; size = arg_size;
fwd = new int[size]; fwd = new int[size];
bck = new int[size]; bck = new int[size];
@ -25,10 +23,8 @@ EquivClass::EquivClass(int arg_size)
bck[0] = ec_nil; bck[0] = ec_nil;
fwd[size - 1] = ec_nil; fwd[size - 1] = ec_nil;
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{ if ( i > 0 ) {
if ( i > 0 )
{
fwd[i - 1] = i; fwd[i - 1] = i;
bck[i] = i - 1; bck[i] = i - 1;
} }
@ -38,8 +34,7 @@ EquivClass::EquivClass(int arg_size)
} }
} }
EquivClass::~EquivClass() EquivClass::~EquivClass() {
{
delete[] fwd; delete[] fwd;
delete[] bck; delete[] bck;
delete[] equiv_class; delete[] equiv_class;
@ -47,8 +42,7 @@ EquivClass::~EquivClass()
delete[] ccl_flags; delete[] ccl_flags;
} }
void EquivClass::ConvertCCL(CCL* ccl) void EquivClass::ConvertCCL(CCL* ccl) {
{
// For each character in the class, add the character's // For each character in the class, add the character's
// equivalence class to the new "character" class we are // equivalence class to the new "character" class we are
// creating. Thus when we are all done, the character class // creating. Thus when we are all done, the character class
@ -58,8 +52,7 @@ void EquivClass::ConvertCCL(CCL* ccl)
int_list* c_syms = ccl->Syms(); int_list* c_syms = ccl->Syms();
int_list* new_syms = new int_list; int_list* new_syms = new int_list;
for ( auto sym : *c_syms ) for ( auto sym : *c_syms ) {
{
if ( IsRep(sym) ) if ( IsRep(sym) )
new_syms->push_back(SymEquivClass(sym)); new_syms->push_back(SymEquivClass(sym));
} }
@ -67,18 +60,15 @@ void EquivClass::ConvertCCL(CCL* ccl)
ccl->ReplaceSyms(new_syms); ccl->ReplaceSyms(new_syms);
} }
int EquivClass::BuildECs() int EquivClass::BuildECs() {
{
// Create equivalence class numbers. If bck[x] is nil, // Create equivalence class numbers. If bck[x] is nil,
// then x is the representative of its equivalence class. // then x is the representative of its equivalence class.
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
if ( bck[i] == ec_nil ) if ( bck[i] == ec_nil ) {
{
equiv_class[i] = num_ecs++; equiv_class[i] = num_ecs++;
rep[i] = i; rep[i] = i;
for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) {
{
equiv_class[j] = equiv_class[i]; equiv_class[j] = equiv_class[i];
rep[j] = i; rep[j] = i;
} }
@ -87,21 +77,18 @@ int EquivClass::BuildECs()
return num_ecs; return num_ecs;
} }
void EquivClass::CCL_Use(CCL* ccl) void EquivClass::CCL_Use(CCL* ccl) {
{
// Note that it doesn't matter whether or not the character class is // Note that it doesn't matter whether or not the character class is
// negated. The same results will be obtained in either case. // negated. The same results will be obtained in either case.
if ( ! ccl_flags ) if ( ! ccl_flags ) {
{
ccl_flags = new int[size]; ccl_flags = new int[size];
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
ccl_flags[i] = 0; ccl_flags[i] = 0;
} }
int_list* csyms = ccl->Syms(); int_list* csyms = ccl->Syms();
for ( size_t i = 0; i < csyms->size(); /* no increment */ ) for ( size_t i = 0; i < csyms->size(); /* no increment */ ) {
{
int sym = (*csyms)[i]; int sym = (*csyms)[i];
int old_ec = bck[sym]; int old_ec = bck[sym];
@ -109,17 +96,14 @@ void EquivClass::CCL_Use(CCL* ccl)
size_t j = i + 1; size_t j = i + 1;
for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) { // look for the symbol in the character class
{ // look for the symbol in the character class for ( ; j < csyms->size(); ++j ) {
for ( ; j < csyms->size(); ++j )
{
if ( (*csyms)[j] > k ) if ( (*csyms)[j] > k )
// Since the character class is sorted, // Since the character class is sorted,
// we can stop. // we can stop.
break; break;
if ( (*csyms)[j] == k && ! ccl_flags[j] ) if ( (*csyms)[j] == k && ! ccl_flags[j] ) {
{
// We found an old companion of sym // We found an old companion of sym
// in the ccl. Link it into the new // in the ccl. Link it into the new
// equivalence class and flag it as // equivalence class and flag it as
@ -150,8 +134,7 @@ void EquivClass::CCL_Use(CCL* ccl)
old_ec = k; old_ec = k;
} }
if ( bck[sym] != ec_nil || old_ec != bck[sym] ) if ( bck[sym] != ec_nil || old_ec != bck[sym] ) {
{
bck[sym] = ec_nil; bck[sym] = ec_nil;
fwd[old_ec] = ec_nil; fwd[old_ec] = ec_nil;
} }
@ -165,8 +148,7 @@ void EquivClass::CCL_Use(CCL* ccl)
} }
} }
void EquivClass::UniqueChar(int sym) void EquivClass::UniqueChar(int sym) {
{
// If until now the character has been a proper subset of // If until now the character has been a proper subset of
// an equivalence class, break it away to create a new ec. // an equivalence class, break it away to create a new ec.
@ -180,17 +162,13 @@ void EquivClass::UniqueChar(int sym)
bck[sym] = ec_nil; bck[sym] = ec_nil;
} }
void EquivClass::Dump(FILE* f) void EquivClass::Dump(FILE* f) {
{
fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs); fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs);
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i )
if ( SymEquivClass(i) != 0 ) // skip usually huge default ec if ( SymEquivClass(i) != 0 ) // skip usually huge default ec
fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i)); fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i));
} }
int EquivClass::Size() const int EquivClass::Size() const { return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4)); }
{
return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4));
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -4,13 +4,11 @@
#include <stdio.h> #include <stdio.h>
namespace zeek::detail namespace zeek::detail {
{
class CCL; class CCL;
class EquivClass class EquivClass {
{
public: public:
explicit EquivClass(int size); explicit EquivClass(int size);
~EquivClass(); ~EquivClass();

View file

@ -15,20 +15,22 @@
zeek::EventMgr zeek::event_mgr; zeek::EventMgr zeek::event_mgr;
namespace zeek namespace zeek {
{
Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, util::detail::SourceID arg_src,
util::detail::SourceID arg_src, analyzer::ID arg_aid, Obj* arg_obj, double arg_ts) analyzer::ID arg_aid, Obj* arg_obj, double arg_ts)
: handler(arg_handler), args(std::move(arg_args)), src(arg_src), aid(arg_aid), ts(arg_ts), : handler(arg_handler),
obj(arg_obj), next_event(nullptr) args(std::move(arg_args)),
{ src(arg_src),
aid(arg_aid),
ts(arg_ts),
obj(arg_obj),
next_event(nullptr) {
if ( obj ) if ( obj )
Ref(obj); Ref(obj);
} }
void Event::Describe(ODesc* d) const void Event::Describe(ODesc* d) const {
{
if ( d->IsReadable() ) if ( d->IsReadable() )
d->AddSP("event"); d->AddSP("event");
@ -42,21 +44,18 @@ void Event::Describe(ODesc* d) const
d->Add("("); d->Add("(");
} }
void Event::Dispatch(bool no_remote) void Event::Dispatch(bool no_remote) {
{
if ( src == util::detail::SOURCE_BROKER ) if ( src == util::detail::SOURCE_BROKER )
no_remote = true; no_remote = true;
if ( handler->ErrorHandler() ) if ( handler->ErrorHandler() )
reporter->BeginErrorHandler(); reporter->BeginErrorHandler();
try try {
{
handler->Call(&args, no_remote, ts); handler->Call(&args, no_remote, ts);
} }
catch ( InterpreterException& e ) catch ( InterpreterException& e ) {
{
// Already reported. // Already reported.
} }
@ -68,8 +67,7 @@ void Event::Dispatch(bool no_remote)
reporter->EndErrorHandler(); reporter->EndErrorHandler();
} }
EventMgr::EventMgr() EventMgr::EventMgr() {
{
head = tail = nullptr; head = tail = nullptr;
current_src = util::detail::SOURCE_LOCAL; current_src = util::detail::SOURCE_LOCAL;
current_aid = 0; current_aid = 0;
@ -78,10 +76,8 @@ EventMgr::EventMgr()
draining = false; draining = false;
} }
EventMgr::~EventMgr() EventMgr::~EventMgr() {
{ while ( head ) {
while ( head )
{
Event* n = head->NextEvent(); Event* n = head->NextEvent();
Unref(head); Unref(head);
head = n; head = n;
@ -90,26 +86,22 @@ EventMgr::~EventMgr()
Unref(src_val); Unref(src_val);
} }
void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, analyzer::ID aid, Obj* obj,
analyzer::ID aid, Obj* obj, double ts) double ts) {
{
QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts)); QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts));
} }
void EventMgr::QueueEvent(Event* event) void EventMgr::QueueEvent(Event* event) {
{
bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false); bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false);
if ( done ) if ( done )
return; return;
if ( ! head ) if ( ! head ) {
{
head = tail = event; head = tail = event;
queue_flare.Fire(); queue_flare.Fire();
} }
else else {
{
tail->SetNext(event); tail->SetNext(event);
tail = event; tail = event;
} }
@ -117,8 +109,7 @@ void EventMgr::QueueEvent(Event* event)
++event_mgr.num_events_queued; ++event_mgr.num_events_queued;
} }
void EventMgr::Dispatch(Event* event, bool no_remote) void EventMgr::Dispatch(Event* event, bool no_remote) {
{
current_src = event->Source(); current_src = event->Source();
current_aid = event->Analyzer(); current_aid = event->Analyzer();
current_ts = event->Time(); current_ts = event->Time();
@ -126,8 +117,7 @@ void EventMgr::Dispatch(Event* event, bool no_remote)
Unref(event); Unref(event);
} }
void EventMgr::Drain() void EventMgr::Drain() {
{
if ( event_queue_flush_point ) if ( event_queue_flush_point )
Enqueue(event_queue_flush_point, Args{}); Enqueue(event_queue_flush_point, Args{});
@ -144,14 +134,12 @@ void EventMgr::Drain()
// just one round to make it less likely to break existing scripts // just one round to make it less likely to break existing scripts
// that expect the old behavior to trigger something quickly. // that expect the old behavior to trigger something quickly.
for ( int round = 0; head && round < 2; round++ ) for ( int round = 0; head && round < 2; round++ ) {
{
Event* current = head; Event* current = head;
head = nullptr; head = nullptr;
tail = nullptr; tail = nullptr;
while ( current ) while ( current ) {
{
Event* next = current->NextEvent(); Event* next = current->NextEvent();
current_src = current->Source(); current_src = current->Source();
@ -174,8 +162,7 @@ void EventMgr::Drain()
detail::trigger_mgr->Process(); detail::trigger_mgr->Process();
} }
void EventMgr::Describe(ODesc* d) const void EventMgr::Describe(ODesc* d) const {
{
int n = 0; int n = 0;
Event* e; Event* e;
for ( e = head; e; e = e->NextEvent() ) for ( e = head; e; e = e->NextEvent() )
@ -183,15 +170,13 @@ void EventMgr::Describe(ODesc* d) const
d->AddCount(n); d->AddCount(n);
for ( e = head; e; e = e->NextEvent() ) for ( e = head; e; e = e->NextEvent() ) {
{
e->Describe(d); e->Describe(d);
d->NL(); d->NL();
} }
} }
void EventMgr::Process() void EventMgr::Process() {
{
queue_flare.Extinguish(); queue_flare.Extinguish();
// While it semes like the most logical thing to do, we dont want // While it semes like the most logical thing to do, we dont want
@ -200,15 +185,13 @@ void EventMgr::Process()
// and had the opportunity to spawn new events. // and had the opportunity to spawn new events.
} }
void EventMgr::InitPostScript() void EventMgr::InitPostScript() {
{
iosource_mgr->Register(this, true, false); iosource_mgr->Register(this, true, false);
if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) ) if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) )
reporter->FatalError("Failed to register event manager FD with iosource_mgr"); reporter->FatalError("Failed to register event manager FD with iosource_mgr");
} }
void EventMgr::InitPostFork() void EventMgr::InitPostFork() {
{
// Re-initialize the flare, closing and re-opening the underlying // Re-initialize the flare, closing and re-opening the underlying
// pipe FDs. This is needed so that each Zeek process in a supervisor // pipe FDs. This is needed so that each Zeek process in a supervisor
// setup has its own pipe instead of them all sharing a single pipe. // setup has its own pipe instead of them all sharing a single pipe.

View file

@ -12,22 +12,18 @@
#include "zeek/analyzer/Analyzer.h" #include "zeek/analyzer/Analyzer.h"
#include "zeek/iosource/IOSource.h" #include "zeek/iosource/IOSource.h"
namespace zeek namespace zeek {
{
namespace run_state namespace run_state {
{
extern double network_time; extern double network_time;
} // namespace run_state } // namespace run_state
class EventMgr; class EventMgr;
class Event final : public Obj class Event final : public Obj {
{
public: public:
Event(const EventHandlerPtr& handler, zeek::Args args, Event(const EventHandlerPtr& handler, zeek::Args args, util::detail::SourceID src = util::detail::SOURCE_LOCAL,
util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time);
Obj* obj = nullptr, double ts = run_state::network_time);
void SetNext(Event* n) { next_event = n; } void SetNext(Event* n) { next_event = n; }
Event* NextEvent() const { return next_event; } Event* NextEvent() const { return next_event; }
@ -56,8 +52,7 @@ protected:
Event* next_event; Event* next_event;
}; };
class EventMgr final : public Obj, public iosource::IOSource class EventMgr final : public Obj, public iosource::IOSource {
{
public: public:
EventMgr(); EventMgr();
~EventMgr() override; ~EventMgr() override;
@ -76,17 +71,15 @@ public:
* @param ts timestamp at which the event is intended to be executed * @param ts timestamp at which the event is intended to be executed
* (defaults to current network time). * (defaults to current network time).
*/ */
void Enqueue(const EventHandlerPtr& h, zeek::Args vl, void Enqueue(const EventHandlerPtr& h, zeek::Args vl, util::detail::SourceID src = util::detail::SOURCE_LOCAL,
util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time);
Obj* obj = nullptr, double ts = run_state::network_time);
/** /**
* A version of Enqueue() taking a variable number of arguments. * A version of Enqueue() taking a variable number of arguments.
*/ */
template<class... Args> template<class... Args>
std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>> std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>> Enqueue(
Enqueue(const EventHandlerPtr& h, Args&&... args) const EventHandlerPtr& h, Args&&... args) {
{
return Enqueue(h, zeek::Args{std::forward<Args>(args)...}); return Enqueue(h, zeek::Args{std::forward<Args>(args)...});
} }

View file

@ -11,11 +11,9 @@
#include "zeek/broker/Manager.h" #include "zeek/broker/Manager.h"
#include "zeek/telemetry/Manager.h" #include "zeek/telemetry/Manager.h"
namespace zeek namespace zeek {
{
EventHandler::EventHandler(std::string arg_name) EventHandler::EventHandler(std::string arg_name) {
{
name = std::move(arg_name); name = std::move(arg_name);
used = false; used = false;
error_handler = false; error_handler = false;
@ -23,19 +21,15 @@ EventHandler::EventHandler(std::string arg_name)
generate_always = false; generate_always = false;
} }
EventHandler::operator bool() const EventHandler::operator bool() const {
{ return enabled && ((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty());
return enabled &&
((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty());
} }
const FuncTypePtr& EventHandler::GetType(bool check_export) const FuncTypePtr& EventHandler::GetType(bool check_export) {
{
if ( type ) if ( type )
return type; return type;
const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, check_export);
check_export);
if ( ! id ) if ( ! id )
return FuncType::nil; return FuncType::nil;
@ -47,17 +41,12 @@ const FuncTypePtr& EventHandler::GetType(bool check_export)
return type; return type;
} }
void EventHandler::SetFunc(FuncPtr f) void EventHandler::SetFunc(FuncPtr f) { local = std::move(f); }
{
local = std::move(f);
}
void EventHandler::Call(Args* vl, bool no_remote, double ts) void EventHandler::Call(Args* vl, bool no_remote, double ts) {
{ if ( ! call_count ) {
if ( ! call_count ) static auto eh_invocations_family =
{ telemetry_mgr->CounterFamily("zeek", "event-handler-invocations", {"name"},
static auto eh_invocations_family = telemetry_mgr->CounterFamily(
"zeek", "event-handler-invocations", {"name"},
"Number of times the given event handler was called", "1", true); "Number of times the given event handler was called", "1", true);
call_count = eh_invocations_family.GetOrAdd({{"name", name}}); call_count = eh_invocations_family.GetOrAdd({{"name", name}});
@ -68,23 +57,19 @@ void EventHandler::Call(Args* vl, bool no_remote, double ts)
if ( new_event ) if ( new_event )
NewEvent(vl); NewEvent(vl);
if ( ! no_remote ) if ( ! no_remote ) {
{ if ( ! auto_publish.empty() ) {
if ( ! auto_publish.empty() )
{
// Send event in form [name, xs...] where xs represent the arguments. // Send event in form [name, xs...] where xs represent the arguments.
broker::vector xs; broker::vector xs;
xs.reserve(vl->size()); xs.reserve(vl->size());
bool valid_args = true; bool valid_args = true;
for ( auto i = 0u; i < vl->size(); ++i ) for ( auto i = 0u; i < vl->size(); ++i ) {
{
auto opt_data = Broker::detail::val_to_data((*vl)[i].get()); auto opt_data = Broker::detail::val_to_data((*vl)[i].get());
if ( opt_data ) if ( opt_data )
xs.emplace_back(std::move(*opt_data)); xs.emplace_back(std::move(*opt_data));
else else {
{
valid_args = false; valid_args = false;
auto_publish.clear(); auto_publish.clear();
reporter->Error("failed auto-remote event '%s', disabled", Name()); reporter->Error("failed auto-remote event '%s', disabled", Name());
@ -92,17 +77,14 @@ void EventHandler::Call(Args* vl, bool no_remote, double ts)
} }
} }
if ( valid_args ) if ( valid_args ) {
{ for ( auto it = auto_publish.begin();; ) {
for ( auto it = auto_publish.begin();; )
{
const auto& topic = *it; const auto& topic = *it;
++it; ++it;
if ( it != auto_publish.end() ) if ( it != auto_publish.end() )
broker_mgr->PublishEvent(topic, Name(), xs, ts); broker_mgr->PublishEvent(topic, Name(), xs, ts);
else else {
{
broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts); broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts);
break; break;
} }
@ -116,8 +98,7 @@ void EventHandler::Call(Args* vl, bool no_remote, double ts)
local->Invoke(vl); local->Invoke(vl);
} }
void EventHandler::NewEvent(Args* vl) void EventHandler::NewEvent(Args* vl) {
{
if ( ! new_event ) if ( ! new_event )
return; return;

View file

@ -11,19 +11,16 @@
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
#include "zeek/telemetry/Counter.h" #include "zeek/telemetry/Counter.h"
namespace zeek namespace zeek {
{
namespace run_state namespace run_state {
{
extern double network_time; extern double network_time;
} // namespace run_state } // namespace run_state
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
class EventHandler class EventHandler {
{
public: public:
explicit EventHandler(std::string name); explicit EventHandler(std::string name);
@ -56,10 +53,7 @@ public:
// Flags the event as interesting even if there is no body defined. In // Flags the event as interesting even if there is no body defined. In
// particular, this will then still pass the event on to plugins. // particular, this will then still pass the event on to plugins.
void SetGenerateAlways(bool arg_generate_always = true) void SetGenerateAlways(bool arg_generate_always = true) { generate_always = arg_generate_always; }
{
generate_always = arg_generate_always;
}
bool GenerateAlways() const { return generate_always; } bool GenerateAlways() const { return generate_always; }
uint64_t CallCount() const { return call_count ? call_count->Value() : 0; } uint64_t CallCount() const { return call_count ? call_count->Value() : 0; }
@ -82,19 +76,16 @@ private:
}; };
// Encapsulates a ptr to an event handler to overload the boolean operator. // Encapsulates a ptr to an event handler to overload the boolean operator.
class EventHandlerPtr class EventHandlerPtr {
{
public: public:
EventHandlerPtr(EventHandler* p = nullptr) { handler = p; } EventHandlerPtr(EventHandler* p = nullptr) { handler = p; }
EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; } EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; }
const EventHandlerPtr& operator=(EventHandler* p) const EventHandlerPtr& operator=(EventHandler* p) {
{
handler = p; handler = p;
return *this; return *this;
} }
const EventHandlerPtr& operator=(const EventHandlerPtr& h) const EventHandlerPtr& operator=(const EventHandlerPtr& h) {
{
handler = h.handler; handler = h.handler;
return *this; return *this;
} }

View file

@ -7,20 +7,17 @@
#include "zeek/RE.h" #include "zeek/RE.h"
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek namespace zeek {
{
EventRegistry::EventRegistry() = default; EventRegistry::EventRegistry() = default;
EventRegistry::~EventRegistry() noexcept = default; EventRegistry::~EventRegistry() noexcept = default;
EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) {
{
// If there already is an entry in the registry, we have a // If there already is an entry in the registry, we have a
// local handler on the script layer. // local handler on the script layer.
EventHandler* h = event_registry->Lookup(name); EventHandler* h = event_registry->Lookup(name);
if ( h ) if ( h ) {
{
if ( ! is_from_script ) if ( ! is_from_script )
not_only_from_script.insert(std::string(name)); not_only_from_script.insert(std::string(name));
@ -36,8 +33,7 @@ EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_scri
return h; return h;
} }
void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) {
{
std::string name = handler->Name(); std::string name = handler->Name();
handlers[name] = std::unique_ptr<EventHandler>(handler.Ptr()); handlers[name] = std::unique_ptr<EventHandler>(handler.Ptr());
@ -46,8 +42,7 @@ void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script)
not_only_from_script.insert(name); not_only_from_script.insert(name);
} }
EventHandler* EventRegistry::Lookup(std::string_view name) EventHandler* EventRegistry::Lookup(std::string_view name) {
{
auto it = handlers.find(name); auto it = handlers.find(name);
if ( it != handlers.end() ) if ( it != handlers.end() )
return it->second.get(); return it->second.get();
@ -55,17 +50,14 @@ EventHandler* EventRegistry::Lookup(std::string_view name)
return nullptr; return nullptr;
} }
bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) {
{
return not_only_from_script.count(std::string(name)) > 0; return not_only_from_script.count(std::string(name)) > 0;
} }
EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
if ( v->GetFunc() && pattern->MatchExactly(v->Name()) ) if ( v->GetFunc() && pattern->MatchExactly(v->Name()) )
names.push_back(entry.first); names.push_back(entry.first);
@ -74,12 +66,10 @@ EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern)
return names; return names;
} }
EventRegistry::string_list EventRegistry::UnusedHandlers() EventRegistry::string_list EventRegistry::UnusedHandlers() {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
if ( v->GetFunc() && ! v->Used() ) if ( v->GetFunc() && ! v->Used() )
names.push_back(entry.first); names.push_back(entry.first);
@ -88,12 +78,10 @@ EventRegistry::string_list EventRegistry::UnusedHandlers()
return names; return names;
} }
EventRegistry::string_list EventRegistry::UsedHandlers() EventRegistry::string_list EventRegistry::UsedHandlers() {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
if ( v->GetFunc() && v->Used() ) if ( v->GetFunc() && v->Used() )
names.push_back(entry.first); names.push_back(entry.first);
@ -102,54 +90,44 @@ EventRegistry::string_list EventRegistry::UsedHandlers()
return names; return names;
} }
EventRegistry::string_list EventRegistry::AllHandlers() EventRegistry::string_list EventRegistry::AllHandlers() {
{
string_list names; string_list names;
for ( const auto& entry : handlers ) for ( const auto& entry : handlers ) {
{
names.push_back(entry.first); names.push_back(entry.first);
} }
return names; return names;
} }
void EventRegistry::PrintDebug() void EventRegistry::PrintDebug() {
{ for ( const auto& entry : handlers ) {
for ( const auto& entry : handlers )
{
EventHandler* v = entry.second.get(); EventHandler* v = entry.second.get();
fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), v->GetFunc() ? "local" : "no",
v->GetFunc() ? "local" : "no", *v ? "active" : "not active"); *v ? "active" : "not active");
} }
} }
void EventRegistry::SetErrorHandler(std::string_view name) void EventRegistry::SetErrorHandler(std::string_view name) {
{
EventHandler* eh = Lookup(name); EventHandler* eh = Lookup(name);
if ( eh ) if ( eh ) {
{
eh->SetErrorHandler(); eh->SetErrorHandler();
return; return;
} }
reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", std::string(name).c_str());
std::string(name).c_str());
} }
void EventRegistry::ActivateAllHandlers() void EventRegistry::ActivateAllHandlers() {
{
auto event_names = AllHandlers(); auto event_names = AllHandlers();
for ( const auto& name : event_names ) for ( const auto& name : event_names ) {
{
if ( auto event = Lookup(name) ) if ( auto event = Lookup(name) )
event->SetGenerateAlways(); event->SetGenerateAlways();
} }
} }
EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) {
{
auto key = std::pair{kind, std::string{name}}; auto key = std::pair{kind, std::string{name}};
if ( const auto& it = event_groups.find(key); it != event_groups.end() ) if ( const auto& it = event_groups.find(key); it != event_groups.end() )
return it->second; return it->second;
@ -158,8 +136,7 @@ EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view
return event_groups.emplace(key, group).first->second; return event_groups.emplace(key, group).first->second;
} }
EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) {
{
auto key = std::pair{kind, std::string{name}}; auto key = std::pair{kind, std::string{name}};
if ( const auto& it = event_groups.find(key); it != event_groups.end() ) if ( const auto& it = event_groups.find(key); it != event_groups.end() )
return it->second; return it->second;
@ -177,29 +154,19 @@ EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind),
// EventGroup is private friend with Func, so fiddling with the bodies // EventGroup is private friend with Func, so fiddling with the bodies
// and private members works and keeps the logic out of Func and away // and private members works and keeps the logic out of Func and away
// from the public zeek:: namespace. // from the public zeek:: namespace.
void EventGroup::UpdateFuncBodies() void EventGroup::UpdateFuncBodies() {
{ static auto is_group_disabled = [](const auto& g) { return g->IsDisabled(); };
static auto is_group_disabled = [](const auto& g)
{
return g->IsDisabled();
};
for ( auto& func : funcs ) for ( auto& func : funcs ) {
{
for ( auto& b : func->bodies ) for ( auto& b : func->bodies )
b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled); b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled);
static auto is_body_enabled = [](const auto& b) static auto is_body_enabled = [](const auto& b) { return ! b.disabled; };
{ func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(), is_body_enabled);
return ! b.disabled;
};
func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(),
is_body_enabled);
} }
} }
void EventGroup::Enable() void EventGroup::Enable() {
{
if ( enabled ) if ( enabled )
return; return;
@ -208,8 +175,7 @@ void EventGroup::Enable()
UpdateFuncBodies(); UpdateFuncBodies();
} }
void EventGroup::Disable() void EventGroup::Disable() {
{
if ( ! enabled ) if ( ! enabled )
return; return;
@ -218,9 +184,6 @@ void EventGroup::Disable()
UpdateFuncBodies(); UpdateFuncBodies();
} }
void EventGroup::AddFunc(detail::ScriptFuncPtr f) void EventGroup::AddFunc(detail::ScriptFuncPtr f) { funcs.insert(f); }
{
funcs.insert(f);
}
} // namespace zeek } // namespace zeek

View file

@ -13,12 +13,10 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
// The different kinds of event groups that exist. // The different kinds of event groups that exist.
enum class EventGroupKind enum class EventGroupKind {
{
Attribute, Attribute,
Module, Module,
}; };
@ -30,15 +28,13 @@ class RE_Matcher;
using EventGroupPtr = std::shared_ptr<EventGroup>; using EventGroupPtr = std::shared_ptr<EventGroup>;
namespace detail namespace detail {
{
class ScriptFunc; class ScriptFunc;
using ScriptFuncPtr = zeek::IntrusivePtr<ScriptFunc>; using ScriptFuncPtr = zeek::IntrusivePtr<ScriptFunc>;
} } // namespace detail
// The registry keeps track of all events that we provide or handle. // The registry keeps track of all events that we provide or handle.
class EventRegistry final class EventRegistry final {
{
public: public:
EventRegistry(); EventRegistry();
~EventRegistry() noexcept; ~EventRegistry() noexcept;
@ -110,8 +106,7 @@ private:
std::unordered_set<std::string> not_only_from_script; std::unordered_set<std::string> not_only_from_script;
// Map event groups identified by kind and name to their instances. // Map event groups identified by kind and name to their instances.
std::map<std::pair<EventGroupKind, std::string>, std::shared_ptr<EventGroup>, std::less<>> std::map<std::pair<EventGroupKind, std::string>, std::shared_ptr<EventGroup>, std::less<>> event_groups;
event_groups;
}; };
/** /**
@ -135,8 +130,7 @@ private:
* bodies of the tracked ScriptFuncs and updates them to reflect the current * bodies of the tracked ScriptFuncs and updates them to reflect the current
* group state. * group state.
*/ */
class EventGroup final class EventGroup final {
{
public: public:
EventGroup(EventGroupKind kind, std::string_view name); EventGroup(EventGroupKind kind, std::string_view name);
~EventGroup() noexcept = default; ~EventGroup() noexcept = default;

File diff suppressed because it is too large Load diff

View file

@ -5,16 +5,14 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
namespace zeek::detail namespace zeek::detail {
{
class ValTrace; class ValTrace;
class ValTraceMgr; class ValTraceMgr;
// Abstract class for capturing a single difference between two script-level // Abstract class for capturing a single difference between two script-level
// values. Includes notions of inserting, changing, or deleting a value. // values. Includes notions of inserting, changing, or deleting a value.
class ValDelta class ValDelta {
{
public: public:
ValDelta(const ValTrace* _vt) : vt(_vt) {} ValDelta(const ValTrace* _vt) : vt(_vt) {}
virtual ~ValDelta() {} virtual ~ValDelta() {}
@ -43,8 +41,7 @@ using DeltaVector = std::vector<std::unique_ptr<ValDelta>>;
// For non-aggregates, this is simply the Val object, but for aggregates // For non-aggregates, this is simply the Val object, but for aggregates
// it is (recursively) each of the sub-elements, in a manner that can then // it is (recursively) each of the sub-elements, in a manner that can then
// be readily compared against future instances. // be readily compared against future instances.
class ValTrace class ValTrace {
{
public: public:
ValTrace(const ValPtr& v); ValTrace(const ValPtr& v);
~ValTrace() = default; ~ValTrace() = default;
@ -107,13 +104,9 @@ private:
}; };
// Captures the basic notion of a new, non-equivalent value being assigned. // Captures the basic notion of a new, non-equivalent value being assigned.
class DeltaReplaceValue : public ValDelta class DeltaReplaceValue : public ValDelta {
{
public: public:
DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) : ValDelta(_vt), new_val(std::move(_new_val)) {}
: ValDelta(_vt), new_val(std::move(_new_val))
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
@ -122,13 +115,10 @@ private:
}; };
// Captures the notion of setting a record field. // Captures the notion of setting a record field.
class DeltaSetField : public ValDelta class DeltaSetField : public ValDelta {
{
public: public:
DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val) DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val)
: ValDelta(_vt), field(_field), new_val(std::move(_new_val)) : ValDelta(_vt), field(_field), new_val(std::move(_new_val)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
@ -138,8 +128,7 @@ private:
}; };
// Captures the notion of deleting a record field. // Captures the notion of deleting a record field.
class DeltaRemoveField : public ValDelta class DeltaRemoveField : public ValDelta {
{
public: public:
DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) {} DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) {}
@ -151,8 +140,7 @@ private:
}; };
// Captures the notion of creating a record from scratch. // Captures the notion of creating a record from scratch.
class DeltaRecordCreate : public ValDelta class DeltaRecordCreate : public ValDelta {
{
public: public:
DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) {} DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) {}
@ -161,8 +149,7 @@ public:
// Captures the notion of adding an element to a set. Use DeltaRemoveTableEntry to // Captures the notion of adding an element to a set. Use DeltaRemoveTableEntry to
// delete values. // delete values.
class DeltaSetSetEntry : public ValDelta class DeltaSetSetEntry : public ValDelta {
{
public: public:
DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) {} DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) {}
@ -176,13 +163,10 @@ private:
// Captures the notion of setting a table entry (which includes both changing // Captures the notion of setting a table entry (which includes both changing
// an existing one and adding a new one). Use DeltaRemoveTableEntry to // an existing one and adding a new one). Use DeltaRemoveTableEntry to
// delete values. // delete values.
class DeltaSetTableEntry : public ValDelta class DeltaSetTableEntry : public ValDelta {
{
public: public:
DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val) DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val)
: ValDelta(_vt), index(_index), new_val(std::move(_new_val)) : ValDelta(_vt), index(_index), new_val(std::move(_new_val)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
@ -192,13 +176,9 @@ private:
}; };
// Captures the notion of removing a table/set entry. // Captures the notion of removing a table/set entry.
class DeltaRemoveTableEntry : public ValDelta class DeltaRemoveTableEntry : public ValDelta {
{
public: public:
DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(std::move(_index)) {}
: ValDelta(_vt), index(std::move(_index))
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
bool NeedsLHS() const override { return false; } bool NeedsLHS() const override { return false; }
@ -208,8 +188,7 @@ private:
}; };
// Captures the notion of creating a set from scratch. // Captures the notion of creating a set from scratch.
class DeltaSetCreate : public ValDelta class DeltaSetCreate : public ValDelta {
{
public: public:
DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) {} DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) {}
@ -217,8 +196,7 @@ public:
}; };
// Captures the notion of creating a table from scratch. // Captures the notion of creating a table from scratch.
class DeltaTableCreate : public ValDelta class DeltaTableCreate : public ValDelta {
{
public: public:
DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) {} DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) {}
@ -226,13 +204,10 @@ public:
}; };
// Captures the notion of changing an element of a vector. // Captures the notion of changing an element of a vector.
class DeltaVectorSet : public ValDelta class DeltaVectorSet : public ValDelta {
{
public: public:
DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem) DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem)
: ValDelta(_vt), index(_index), elem(std::move(_elem)) : ValDelta(_vt), index(_index), elem(std::move(_elem)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
@ -242,13 +217,10 @@ private:
}; };
// Captures the notion of adding an entry to the end of a vector. // Captures the notion of adding an entry to the end of a vector.
class DeltaVectorAppend : public ValDelta class DeltaVectorAppend : public ValDelta {
{
public: public:
DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem) DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem)
: ValDelta(_vt), index(_index), elem(std::move(_elem)) : ValDelta(_vt), index(_index), elem(std::move(_elem)) {}
{
}
std::string Generate(ValTraceMgr* vtm) const override; std::string Generate(ValTraceMgr* vtm) const override;
@ -258,8 +230,7 @@ private:
}; };
// Captures the notion of replacing a vector wholesale. // Captures the notion of replacing a vector wholesale.
class DeltaVectorCreate : public ValDelta class DeltaVectorCreate : public ValDelta {
{
public: public:
DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) {} DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) {}
@ -268,8 +239,7 @@ public:
// Captures the notion of creating a value with an unsupported type // Captures the notion of creating a value with an unsupported type
// (like "opaque"). // (like "opaque").
class DeltaUnsupportedCreate : public ValDelta class DeltaUnsupportedCreate : public ValDelta {
{
public: public:
DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) {} DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) {}
@ -278,14 +248,10 @@ public:
// Manages the changes to (or creation of) a variable used to represent // Manages the changes to (or creation of) a variable used to represent
// a value. // a value.
class DeltaGen class DeltaGen {
{
public: public:
DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def) DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def)
: val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), : val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), is_first_def(_is_first_def) {}
is_first_def(_is_first_def)
{
}
const ValPtr& GetVal() const { return val; } const ValPtr& GetVal() const { return val; }
const std::string& RHS() const { return rhs; } const std::string& RHS() const { return rhs; }
@ -310,8 +276,7 @@ private:
using DeltaGenVec = std::vector<DeltaGen>; using DeltaGenVec = std::vector<DeltaGen>;
// Tracks a single event. // Tracks a single event.
class EventTrace class EventTrace {
{
public: public:
// Constructed in terms of the associated script function, "network // Constructed in terms of the associated script function, "network
// time" when the event occurred, and the position of this event // time" when the event occurred, and the position of this event
@ -323,8 +288,7 @@ public:
void SetArgs(std::string _args) { args = std::move(_args); } void SetArgs(std::string _args) { args = std::move(_args); }
// Adds to the trace an update for the given value. // Adds to the trace an update for the given value.
void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) {
{
auto& d = is_post ? post_deltas : deltas; auto& d = is_post ? post_deltas : deltas;
d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def)); d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def));
} }
@ -346,14 +310,12 @@ public:
// "predecessor", if non-nil, gives the event that came just before // "predecessor", if non-nil, gives the event that came just before
// this one (used for "# from script" annotations"). "successor", // this one (used for "# from script" annotations"). "successor",
// if not empty, gives the name of the successor internal event. // if not empty, gives the name of the successor internal event.
void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, std::string successor) const;
std::string successor) const;
private: private:
// "dvec" is either just our deltas, or the "post_deltas" of our // "dvec" is either just our deltas, or the "post_deltas" of our
// predecessor plus our deltas. // predecessor plus our deltas.
void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, int num_pre = 0) const;
int num_pre = 0) const;
const ScriptFunc* ev; const ScriptFunc* ev;
double nt; double nt;
@ -373,8 +335,7 @@ private:
}; };
// Manages all of the events and associated values seen during the execution. // Manages all of the events and associated values seen during the execution.
class ValTraceMgr class ValTraceMgr {
{
public: public:
// Invoked to trace a new event with the associated arguments. // Invoked to trace a new event with the associated arguments.
void TraceEventValues(std::shared_ptr<EventTrace> et, const zeek::Args* args); void TraceEventValues(std::shared_ptr<EventTrace> et, const zeek::Args* args);
@ -474,8 +435,7 @@ private:
// Manages tracing of all of the events seen during execution, including // Manages tracing of all of the events seen during execution, including
// the final generation of the trace script. // the final generation of the trace script.
class EventTraceMgr class EventTraceMgr {
{
public: public:
EventTraceMgr(const std::string& trace_file); EventTraceMgr(const std::string& trace_file);
~EventTraceMgr(); ~EventTraceMgr();

File diff suppressed because it is too large Load diff

View file

@ -18,12 +18,11 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
namespace zeek namespace zeek {
{ template<class T>
template <class T> class IntrusivePtr; class IntrusivePtr;
namespace detail namespace detail {
{
class Frame; class Frame;
class Scope; class Scope;
@ -34,8 +33,7 @@ using ScopePtr = IntrusivePtr<Scope>;
using ScriptFuncPtr = IntrusivePtr<ScriptFunc>; using ScriptFuncPtr = IntrusivePtr<ScriptFunc>;
using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>; using FunctionIngredientsPtr = std::shared_ptr<FunctionIngredients>;
enum ExprTag : int enum ExprTag : int {
{
EXPR_ANY = -1, EXPR_ANY = -1,
EXPR_NAME, EXPR_NAME,
EXPR_CONST, EXPR_CONST,
@ -150,17 +148,18 @@ using StmtPtr = IntrusivePtr<Stmt>;
class ExprOptInfo; class ExprOptInfo;
class Expr : public Obj class Expr : public Obj {
{
public: public:
const TypePtr& GetType() const { return type; } const TypePtr& GetType() const { return type; }
template <class T> IntrusivePtr<T> GetType() const { return cast_intrusive<T>(type); } template<class T>
IntrusivePtr<T> GetType() const {
return cast_intrusive<T>(type);
}
ExprTag Tag() const { return tag; } ExprTag Tag() const { return tag; }
Expr* Ref() Expr* Ref() {
{
zeek::Ref(this); zeek::Ref(this);
return this; return this;
} }
@ -270,10 +269,7 @@ public:
// True if the expression can serve as an operand to a reduced // True if the expression can serve as an operand to a reduced
// expression. // expression.
bool IsSingleton(Reducer* r) const bool IsSingleton(Reducer* r) const { return (tag == EXPR_NAME && IsReduced(r)) || tag == EXPR_CONST; }
{
return (tag == EXPR_NAME && IsReduced(r)) || tag == EXPR_CONST;
}
// True if the expression has no side effects, false otherwise. // True if the expression has no side effects, false otherwise.
virtual bool HasNoSideEffects() const { return IsPure(); } virtual bool HasNoSideEffects() const { return IsPure(); }
@ -287,8 +283,7 @@ public:
// True if (a) the expression has at least one operand, and (b) all // True if (a) the expression has at least one operand, and (b) all
// of its operands are constant. // of its operands are constant.
bool HasConstantOps() const bool HasConstantOps() const {
{
return GetOp1() && GetOp1()->IsConst() && return GetOp1() && GetOp1()->IsConst() &&
(! GetOp2() || (GetOp2()->IsConst() && (! GetOp3() || GetOp3()->IsConst()))); (! GetOp2() || (GetOp2()->IsConst() && (! GetOp3() || GetOp3()->IsConst())));
} }
@ -346,10 +341,7 @@ public:
// that's been assigned to the given expression via red_stmt. // that's been assigned to the given expression via red_stmt.
ExprPtr AssignToTemporary(ExprPtr e, Reducer* c, StmtPtr& red_stmt); ExprPtr AssignToTemporary(ExprPtr e, Reducer* c, StmtPtr& red_stmt);
// Same but for this expression. // Same but for this expression.
ExprPtr AssignToTemporary(Reducer* c, StmtPtr& red_stmt) ExprPtr AssignToTemporary(Reducer* c, StmtPtr& red_stmt) { return AssignToTemporary(ThisPtr(), c, red_stmt); }
{
return AssignToTemporary(ThisPtr(), c, red_stmt);
}
// If the expression always evaluates to the same value, returns // If the expression always evaluates to the same value, returns
// that value. Otherwise, returns nullptr. // that value. Otherwise, returns nullptr.
@ -379,8 +371,7 @@ public:
const Expr* Original() const { return original ? original->Original() : this; } const Expr* Original() const { return original ? original->Original() : this; }
// Designate the given Expr node as the original for this one. // Designate the given Expr node as the original for this one.
void SetOriginal(ExprPtr _orig) void SetOriginal(ExprPtr _orig) {
{
if ( ! original ) if ( ! original )
original = std::move(_orig); original = std::move(_orig);
} }
@ -392,16 +383,14 @@ public:
// code, which is always passing in "new XyzExpr(...)". This // code, which is always passing in "new XyzExpr(...)". This
// call, as a convenient side effect, transforms that bare pointer // call, as a convenient side effect, transforms that bare pointer
// into an ExprPtr. // into an ExprPtr.
virtual ExprPtr SetSucc(Expr* succ) virtual ExprPtr SetSucc(Expr* succ) {
{
succ->SetOriginal(ThisPtr()); succ->SetOriginal(ThisPtr());
if ( IsParen() ) if ( IsParen() )
succ->MarkParen(); succ->MarkParen();
return {AdoptRef{}, succ}; return {AdoptRef{}, succ};
} }
const detail::Location* GetLocationInfo() const override const detail::Location* GetLocationInfo() const override {
{
if ( original ) if ( original )
return original->GetLocationInfo(); return original->GetLocationInfo();
else else
@ -458,8 +447,7 @@ protected:
static int num_exprs; static int num_exprs;
}; };
class NameExpr final : public Expr class NameExpr final : public Expr {
{
public: public:
explicit NameExpr(IDPtr id, bool const_init = false); explicit NameExpr(IDPtr id, bool const_init = false);
@ -492,8 +480,7 @@ protected:
bool in_const_init; bool in_const_init;
}; };
class ConstExpr final : public Expr class ConstExpr final : public Expr {
{
public: public:
explicit ConstExpr(ValPtr val); explicit ConstExpr(ValPtr val);
@ -513,8 +500,7 @@ protected:
ValPtr val; ValPtr val;
}; };
class UnaryExpr : public Expr class UnaryExpr : public Expr {
{
public: public:
Expr* Op() const { return op.get(); } Expr* Op() const { return op.get(); }
@ -549,8 +535,7 @@ protected:
ExprPtr op; ExprPtr op;
}; };
class BinaryExpr : public Expr class BinaryExpr : public Expr {
{
public: public:
Expr* Op1() const { return op1.get(); } Expr* Op1() const { return op1.get(); }
Expr* Op2() const { return op2.get(); } Expr* Op2() const { return op2.get(); }
@ -580,8 +565,7 @@ public:
protected: protected:
BinaryExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) BinaryExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2)
: Expr(arg_tag), op1(std::move(arg_op1)), op2(std::move(arg_op2)) : Expr(arg_tag), op1(std::move(arg_op1)), op2(std::move(arg_op2)) {
{
if ( ! (op1 && op2) ) if ( ! (op1 && op2) )
return; return;
if ( op1->IsError() || op2->IsError() ) if ( op1->IsError() || op2->IsError() )
@ -636,8 +620,7 @@ protected:
ExprPtr op2; ExprPtr op2;
}; };
class CloneExpr final : public UnaryExpr class CloneExpr final : public UnaryExpr {
{
public: public:
explicit CloneExpr(ExprPtr op); explicit CloneExpr(ExprPtr op);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -649,8 +632,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class IncrExpr final : public UnaryExpr class IncrExpr final : public UnaryExpr {
{
public: public:
IncrExpr(ExprTag tag, ExprPtr op); IncrExpr(ExprTag tag, ExprPtr op);
@ -668,8 +650,7 @@ public:
ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override;
}; };
class ComplementExpr final : public UnaryExpr class ComplementExpr final : public UnaryExpr {
{
public: public:
explicit ComplementExpr(ExprPtr op); explicit ComplementExpr(ExprPtr op);
@ -682,8 +663,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class NotExpr final : public UnaryExpr class NotExpr final : public UnaryExpr {
{
public: public:
explicit NotExpr(ExprPtr op); explicit NotExpr(ExprPtr op);
@ -696,8 +676,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class PosExpr final : public UnaryExpr class PosExpr final : public UnaryExpr {
{
public: public:
explicit PosExpr(ExprPtr op); explicit PosExpr(ExprPtr op);
@ -710,8 +689,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class NegExpr final : public UnaryExpr class NegExpr final : public UnaryExpr {
{
public: public:
explicit NegExpr(ExprPtr op); explicit NegExpr(ExprPtr op);
@ -724,8 +702,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class SizeExpr final : public UnaryExpr class SizeExpr final : public UnaryExpr {
{
public: public:
explicit SizeExpr(ExprPtr op); explicit SizeExpr(ExprPtr op);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -737,8 +714,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class AddExpr final : public BinaryExpr class AddExpr final : public BinaryExpr {
{
public: public:
AddExpr(ExprPtr op1, ExprPtr op2); AddExpr(ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -752,8 +728,7 @@ protected:
ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2); ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2);
}; };
class AddToExpr final : public BinaryExpr class AddToExpr final : public BinaryExpr {
{
public: public:
AddToExpr(ExprPtr op1, ExprPtr op2); AddToExpr(ExprPtr op1, ExprPtr op2);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -772,8 +747,7 @@ private:
bool is_vector_elem_append = false; bool is_vector_elem_append = false;
}; };
class RemoveFromExpr final : public BinaryExpr class RemoveFromExpr final : public BinaryExpr {
{
public: public:
bool IsPure() const override { return false; } bool IsPure() const override { return false; }
RemoveFromExpr(ExprPtr op1, ExprPtr op2); RemoveFromExpr(ExprPtr op1, ExprPtr op2);
@ -788,8 +762,7 @@ public:
ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override;
}; };
class SubExpr final : public BinaryExpr class SubExpr final : public BinaryExpr {
{
public: public:
SubExpr(ExprPtr op1, ExprPtr op2); SubExpr(ExprPtr op1, ExprPtr op2);
@ -799,8 +772,7 @@ public:
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class TimesExpr final : public BinaryExpr class TimesExpr final : public BinaryExpr {
{
public: public:
TimesExpr(ExprPtr op1, ExprPtr op2); TimesExpr(ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -811,8 +783,7 @@ public:
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class DivideExpr final : public BinaryExpr class DivideExpr final : public BinaryExpr {
{
public: public:
DivideExpr(ExprPtr op1, ExprPtr op2); DivideExpr(ExprPtr op1, ExprPtr op2);
@ -822,8 +793,7 @@ public:
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class MaskExpr final : public BinaryExpr class MaskExpr final : public BinaryExpr {
{
public: public:
MaskExpr(ExprPtr op1, ExprPtr op2); MaskExpr(ExprPtr op1, ExprPtr op2);
@ -834,8 +804,7 @@ protected:
ValPtr AddrFold(Val* v1, Val* v2) const override; ValPtr AddrFold(Val* v1, Val* v2) const override;
}; };
class ModExpr final : public BinaryExpr class ModExpr final : public BinaryExpr {
{
public: public:
ModExpr(ExprPtr op1, ExprPtr op2); ModExpr(ExprPtr op1, ExprPtr op2);
@ -843,8 +812,7 @@ public:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
class BoolExpr final : public BinaryExpr class BoolExpr final : public BinaryExpr {
{
public: public:
BoolExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); BoolExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
@ -862,8 +830,7 @@ protected:
bool IsFalse(const ExprPtr& e) const; bool IsFalse(const ExprPtr& e) const;
}; };
class BitExpr final : public BinaryExpr class BitExpr final : public BinaryExpr {
{
public: public:
BitExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); BitExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
@ -873,8 +840,7 @@ public:
ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override;
}; };
class EqExpr final : public BinaryExpr class EqExpr final : public BinaryExpr {
{
public: public:
EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -889,8 +855,7 @@ protected:
ValPtr Fold(Val* v1, Val* v2) const override; ValPtr Fold(Val* v1, Val* v2) const override;
}; };
class RelExpr final : public BinaryExpr class RelExpr final : public BinaryExpr {
{
public: public:
RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2);
void Canonicalize() override; void Canonicalize() override;
@ -902,8 +867,7 @@ public:
bool InvertSense() override; bool InvertSense() override;
}; };
class CondExpr final : public Expr class CondExpr final : public Expr {
{
public: public:
CondExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); CondExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3);
@ -942,8 +906,7 @@ protected:
ExprPtr op3; ExprPtr op3;
}; };
class RefExpr final : public UnaryExpr class RefExpr final : public UnaryExpr {
{
public: public:
explicit RefExpr(ExprPtr op); explicit RefExpr(ExprPtr op);
@ -962,13 +925,12 @@ public:
StmtPtr ReduceToLHS(Reducer* c); StmtPtr ReduceToLHS(Reducer* c);
}; };
class AssignExpr : public BinaryExpr class AssignExpr : public BinaryExpr {
{
public: public:
// If val is given, evaluating this expression will always yield the val // If val is given, evaluating this expression will always yield the val
// yet still perform the assignment. Used for triggers. // yet still perform the assignment. Used for triggers.
AssignExpr(ExprPtr op1, ExprPtr op2, bool is_init, ValPtr val = nullptr, AssignExpr(ExprPtr op1, ExprPtr op2, bool is_init, ValPtr val = nullptr, const AttributesPtr& attrs = nullptr,
const AttributesPtr& attrs = nullptr, bool type_check = true); bool type_check = true);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
TypePtr InitType() const override; TypePtr InitType() const override;
@ -1008,8 +970,7 @@ protected:
ValPtr val; // optional ValPtr val; // optional
}; };
class IndexSliceAssignExpr final : public AssignExpr class IndexSliceAssignExpr final : public AssignExpr {
{
public: public:
IndexSliceAssignExpr(ExprPtr op1, ExprPtr op2, bool is_init); IndexSliceAssignExpr(ExprPtr op1, ExprPtr op2, bool is_init);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -1018,8 +979,7 @@ public:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
class IndexExpr : public BinaryExpr class IndexExpr : public BinaryExpr {
{
public: public:
IndexExpr(ExprPtr op1, ListExprPtr op2, bool is_slice = false, bool is_inside_when = false); IndexExpr(ExprPtr op1, ListExprPtr op2, bool is_slice = false, bool is_inside_when = false);
@ -1084,8 +1044,7 @@ extern VectorValPtr vector_int_select(VectorTypePtr vt, const VectorVal* v1, con
// //
// TODO: One Fine Day we should do the equivalent for accessing fields // TODO: One Fine Day we should do the equivalent for accessing fields
// in records, too. // in records, too.
class IndexExprWhen final : public IndexExpr class IndexExprWhen final : public IndexExpr {
{
public: public:
static inline std::vector<ValPtr> results = {}; static inline std::vector<ValPtr> results = {};
static inline int evaluating = 0; static inline int evaluating = 0;
@ -1094,20 +1053,16 @@ public:
static void EndEval() { --evaluating; } static void EndEval() { --evaluating; }
static std::vector<ValPtr> TakeAllResults() static std::vector<ValPtr> TakeAllResults() {
{
auto rval = std::move(results); auto rval = std::move(results);
results = {}; results = {};
return rval; return rval;
} }
IndexExprWhen(ExprPtr op1, ListExprPtr op2, bool is_slice = false) IndexExprWhen(ExprPtr op1, ListExprPtr op2, bool is_slice = false)
: IndexExpr(std::move(op1), std::move(op2), is_slice, true) : IndexExpr(std::move(op1), std::move(op2), is_slice, true) {}
{
}
ValPtr Eval(Frame* f) const override ValPtr Eval(Frame* f) const override {
{
auto v = IndexExpr::Eval(f); auto v = IndexExpr::Eval(f);
if ( v && evaluating > 0 ) if ( v && evaluating > 0 )
@ -1120,8 +1075,7 @@ public:
ExprPtr Duplicate() override; ExprPtr Duplicate() override;
}; };
class FieldExpr final : public UnaryExpr class FieldExpr final : public UnaryExpr {
{
public: public:
FieldExpr(ExprPtr op, const char* field_name); FieldExpr(ExprPtr op, const char* field_name);
~FieldExpr() override; ~FieldExpr() override;
@ -1151,8 +1105,7 @@ protected:
// "rec?$fieldname" is true if the value of $fieldname in rec is not nil. // "rec?$fieldname" is true if the value of $fieldname in rec is not nil.
// "rec?$$attrname" is true if the attribute attrname is not nil. // "rec?$$attrname" is true if the attribute attrname is not nil.
class HasFieldExpr final : public UnaryExpr class HasFieldExpr final : public UnaryExpr {
{
public: public:
HasFieldExpr(ExprPtr op, const char* field_name); HasFieldExpr(ExprPtr op, const char* field_name);
~HasFieldExpr() override; ~HasFieldExpr() override;
@ -1175,8 +1128,7 @@ protected:
int field; int field;
}; };
class RecordConstructorExpr final : public Expr class RecordConstructorExpr final : public Expr {
{
public: public:
explicit RecordConstructorExpr(ListExprPtr constructor_list); explicit RecordConstructorExpr(ListExprPtr constructor_list);
@ -1207,8 +1159,7 @@ protected:
std::optional<std::vector<int>> map; std::optional<std::vector<int>> map;
}; };
class TableConstructorExpr final : public UnaryExpr class TableConstructorExpr final : public UnaryExpr {
{
public: public:
TableConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs, TableConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs,
TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr);
@ -1233,8 +1184,7 @@ protected:
AttributesPtr attrs; AttributesPtr attrs;
}; };
class SetConstructorExpr final : public UnaryExpr class SetConstructorExpr final : public UnaryExpr {
{
public: public:
SetConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs, SetConstructorExpr(ListExprPtr constructor_list, std::unique_ptr<std::vector<AttrPtr>> attrs,
TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr);
@ -1259,8 +1209,7 @@ protected:
AttributesPtr attrs; AttributesPtr attrs;
}; };
class VectorConstructorExpr final : public UnaryExpr class VectorConstructorExpr final : public UnaryExpr {
{
public: public:
explicit VectorConstructorExpr(ListExprPtr constructor_list, TypePtr arg_type = nullptr); explicit VectorConstructorExpr(ListExprPtr constructor_list, TypePtr arg_type = nullptr);
@ -1275,8 +1224,7 @@ protected:
void ExprDescribe(ODesc* d) const override; void ExprDescribe(ODesc* d) const override;
}; };
class FieldAssignExpr final : public UnaryExpr class FieldAssignExpr final : public UnaryExpr {
{
public: public:
FieldAssignExpr(const char* field_name, ExprPtr value); FieldAssignExpr(const char* field_name, ExprPtr value);
@ -1303,8 +1251,7 @@ protected:
std::string field_name; std::string field_name;
}; };
class ArithCoerceExpr final : public UnaryExpr class ArithCoerceExpr final : public UnaryExpr {
{
public: public:
ArithCoerceExpr(ExprPtr op, TypeTag t); ArithCoerceExpr(ExprPtr op, TypeTag t);
@ -1319,8 +1266,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class RecordCoerceExpr final : public UnaryExpr class RecordCoerceExpr final : public UnaryExpr {
{
public: public:
RecordCoerceExpr(ExprPtr op, RecordTypePtr r); RecordCoerceExpr(ExprPtr op, RecordTypePtr r);
@ -1339,8 +1285,7 @@ protected:
extern RecordValPtr coerce_to_record(RecordTypePtr rt, Val* v, const std::vector<int>& map); extern RecordValPtr coerce_to_record(RecordTypePtr rt, Val* v, const std::vector<int>& map);
class TableCoerceExpr final : public UnaryExpr class TableCoerceExpr final : public UnaryExpr {
{
public: public:
TableCoerceExpr(ExprPtr op, TableTypePtr r, bool type_check = true); TableCoerceExpr(ExprPtr op, TableTypePtr r, bool type_check = true);
~TableCoerceExpr() override = default; ~TableCoerceExpr() override = default;
@ -1352,8 +1297,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class VectorCoerceExpr final : public UnaryExpr class VectorCoerceExpr final : public UnaryExpr {
{
public: public:
VectorCoerceExpr(ExprPtr op, VectorTypePtr v); VectorCoerceExpr(ExprPtr op, VectorTypePtr v);
~VectorCoerceExpr() override = default; ~VectorCoerceExpr() override = default;
@ -1365,8 +1309,7 @@ protected:
ValPtr Fold(Val* v) const override; ValPtr Fold(Val* v) const override;
}; };
class ScheduleTimer final : public Timer class ScheduleTimer final : public Timer {
{
public: public:
ScheduleTimer(const EventHandlerPtr& event, zeek::Args args, double t); ScheduleTimer(const EventHandlerPtr& event, zeek::Args args, double t);
~ScheduleTimer() override = default; ~ScheduleTimer() override = default;
@ -1378,8 +1321,7 @@ protected:
zeek::Args args; zeek::Args args;
}; };
class ScheduleExpr final : public Expr class ScheduleExpr final : public Expr {
{
public: public:
ScheduleExpr(ExprPtr when, EventExprPtr event); ScheduleExpr(ExprPtr when, EventExprPtr event);
@ -1413,8 +1355,7 @@ protected:
EventExprPtr event; EventExprPtr event;
}; };
class InExpr final : public BinaryExpr class InExpr final : public BinaryExpr {
{
public: public:
InExpr(ExprPtr op1, ExprPtr op2); InExpr(ExprPtr op1, ExprPtr op2);
@ -1427,8 +1368,7 @@ protected:
ValPtr Fold(Val* v1, Val* v2) const override; ValPtr Fold(Val* v1, Val* v2) const override;
}; };
class CallExpr final : public Expr class CallExpr final : public Expr {
{
public: public:
CallExpr(ExprPtr func, ListExprPtr args, bool in_hook = false, bool in_when = false); CallExpr(ExprPtr func, ListExprPtr args, bool in_hook = false, bool in_when = false);
@ -1465,8 +1405,7 @@ protected:
* On evaluation, captures the frame that it is evaluated in. This becomes * On evaluation, captures the frame that it is evaluated in. This becomes
* the closure for the instance of the function that it creates. * the closure for the instance of the function that it creates.
*/ */
class LambdaExpr final : public Expr class LambdaExpr final : public Expr {
{
public: public:
LambdaExpr(FunctionIngredientsPtr ingredients, IDPList outer_ids, std::string name = "", LambdaExpr(FunctionIngredientsPtr ingredients, IDPList outer_ids, std::string name = "",
StmtPtr when_parent = nullptr); StmtPtr when_parent = nullptr);
@ -1532,8 +1471,7 @@ private:
// This comes before EventExpr so that EventExpr::GetOp1 can return its // This comes before EventExpr so that EventExpr::GetOp1 can return its
// arguments as convertible to ExprPtr. // arguments as convertible to ExprPtr.
class ListExpr : public Expr class ListExpr : public Expr {
{
public: public:
ListExpr(); ListExpr();
explicit ListExpr(ExprPtr e); explicit ListExpr(ExprPtr e);
@ -1570,8 +1508,7 @@ protected:
ExprPList exprs; ExprPList exprs;
}; };
class EventExpr final : public Expr class EventExpr final : public Expr {
{
public: public:
EventExpr(const char* name, ListExprPtr args); EventExpr(const char* name, ListExprPtr args);
@ -1602,14 +1539,12 @@ protected:
ListExprPtr args; ListExprPtr args;
}; };
class RecordAssignExpr final : public ListExpr class RecordAssignExpr final : public ListExpr {
{
public: public:
RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init); RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init);
}; };
class CastExpr final : public UnaryExpr class CastExpr final : public UnaryExpr {
{
public: public:
CastExpr(ExprPtr op, TypePtr t); CastExpr(ExprPtr op, TypePtr t);
@ -1625,8 +1560,7 @@ protected:
// and populates "error" with an error message. // and populates "error" with an error message.
extern ValPtr cast_value(ValPtr v, const TypePtr& t, std::string& error); extern ValPtr cast_value(ValPtr v, const TypePtr& t, std::string& error);
class IsExpr final : public UnaryExpr class IsExpr final : public UnaryExpr {
{
public: public:
IsExpr(ExprPtr op, TypePtr t); IsExpr(ExprPtr op, TypePtr t);
@ -1643,11 +1577,9 @@ private:
TypePtr t; TypePtr t;
}; };
class InlineExpr : public Expr class InlineExpr : public Expr {
{
public: public:
InlineExpr(ListExprPtr arg_args, std::vector<IDPtr> params, StmtPtr body, int frame_offset, InlineExpr(ListExprPtr arg_args, std::vector<IDPtr> params, StmtPtr body, int frame_offset, TypePtr ret_type);
TypePtr ret_type);
bool IsPure() const override; bool IsPure() const override;
@ -1676,8 +1608,7 @@ protected:
// A companion to AddToExpr that's for vector-append, instantiated during // A companion to AddToExpr that's for vector-append, instantiated during
// the reduction process. // the reduction process.
class AppendToExpr : public BinaryExpr class AppendToExpr : public BinaryExpr {
{
public: public:
AppendToExpr(ExprPtr op1, ExprPtr op2); AppendToExpr(ExprPtr op1, ExprPtr op2);
ValPtr Eval(Frame* f) const override; ValPtr Eval(Frame* f) const override;
@ -1691,8 +1622,7 @@ public:
}; };
// An internal class for reduced form. // An internal class for reduced form.
class IndexAssignExpr : public BinaryExpr class IndexAssignExpr : public BinaryExpr {
{
public: public:
// "op1[op2] = op3", all reduced. // "op1[op2] = op3", all reduced.
IndexAssignExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); IndexAssignExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3);
@ -1719,8 +1649,7 @@ protected:
}; };
// An internal class for reduced form. // An internal class for reduced form.
class FieldLHSAssignExpr : public BinaryExpr class FieldLHSAssignExpr : public BinaryExpr {
{
public: public:
// "op1$field = RHS", where RHS is reduced with respect to // "op1$field = RHS", where RHS is reduced with respect to
// ReduceToFieldAssignment(). // ReduceToFieldAssignment().
@ -1748,8 +1677,7 @@ protected:
// Expression to explicitly capture conversion to an "any" type, rather // Expression to explicitly capture conversion to an "any" type, rather
// than it occurring implicitly during script interpretation. // than it occurring implicitly during script interpretation.
class CoerceToAnyExpr : public UnaryExpr class CoerceToAnyExpr : public UnaryExpr {
{
public: public:
CoerceToAnyExpr(ExprPtr op); CoerceToAnyExpr(ExprPtr op);
@ -1760,8 +1688,7 @@ protected:
}; };
// Same, but for conversion from an "any" type. // Same, but for conversion from an "any" type.
class CoerceFromAnyExpr : public UnaryExpr class CoerceFromAnyExpr : public UnaryExpr {
{
public: public:
CoerceFromAnyExpr(ExprPtr op, TypePtr to_type); CoerceFromAnyExpr(ExprPtr op, TypePtr to_type);
@ -1772,8 +1699,7 @@ protected:
}; };
// ... and for conversion from a "vector of any" type. // ... and for conversion from a "vector of any" type.
class CoerceFromAnyVecExpr : public UnaryExpr class CoerceFromAnyVecExpr : public UnaryExpr {
{
public: public:
// to_type is yield type, not VectorType. // to_type is yield type, not VectorType.
CoerceFromAnyVecExpr(ExprPtr op, TypePtr to_type); CoerceFromAnyVecExpr(ExprPtr op, TypePtr to_type);
@ -1787,8 +1713,7 @@ protected:
}; };
// Expression used to explicitly capture [a, b, c, ...] = x assignments. // Expression used to explicitly capture [a, b, c, ...] = x assignments.
class AnyIndexExpr : public UnaryExpr class AnyIndexExpr : public UnaryExpr {
{
public: public:
AnyIndexExpr(ExprPtr op, int index); AnyIndexExpr(ExprPtr op, int index);
@ -1806,8 +1731,7 @@ protected:
}; };
// Used internally for optimization, when a placeholder is needed. // Used internally for optimization, when a placeholder is needed.
class NopExpr : public Expr class NopExpr : public Expr {
{
public: public:
explicit NopExpr() : Expr(EXPR_NOP) {} explicit NopExpr() : Expr(EXPR_NOP) {}
@ -1825,8 +1749,7 @@ protected:
// Factored out so that compiled code can call it as well as the interpreter. // Factored out so that compiled code can call it as well as the interpreter.
extern const char* assign_to_index(ValPtr v1, ValPtr v2, ValPtr v3, bool& iterators_invalidated); extern const char* assign_to_index(ValPtr v1, ValPtr v2, ValPtr v3, bool& iterators_invalidated);
inline Val* Expr::ExprVal() const inline Val* Expr::ExprVal() const {
{
if ( ! IsConst() ) if ( ! IsConst() )
BadTag("ExprVal::Val", expr_name(tag), expr_name(EXPR_CONST)); BadTag("ExprVal::Val", expr_name(tag), expr_name(EXPR_CONST));
return ((ConstExpr*)this)->Value(); return ((ConstExpr*)this)->Value();
@ -1869,25 +1792,13 @@ extern std::optional<std::vector<ValPtr>> eval_list(Frame* f, const ListExpr* l)
extern bool expr_greater(const Expr* e1, const Expr* e2); extern bool expr_greater(const Expr* e1, const Expr* e2);
// True if the given Expr* has a vector type // True if the given Expr* has a vector type
inline bool is_vector(Expr* e) inline bool is_vector(Expr* e) { return e->GetType()->Tag() == TYPE_VECTOR; }
{ inline bool is_vector(const ExprPtr& e) { return is_vector(e.get()); }
return e->GetType()->Tag() == TYPE_VECTOR;
}
inline bool is_vector(const ExprPtr& e)
{
return is_vector(e.get());
}
// True if the given Expr* has a list type // True if the given Expr* has a list type
inline bool is_list(Expr* e) inline bool is_list(Expr* e) { return e->GetType()->Tag() == TYPE_LIST; }
{
return e->GetType()->Tag() == TYPE_LIST;
}
inline bool is_list(const ExprPtr& e) inline bool is_list(const ExprPtr& e) { return is_list(e.get()); }
{
return is_list(e.get());
}
} // namespace detail } // namespace detail
} // namespace zeek } // namespace zeek

View file

@ -31,20 +31,17 @@
#include "zeek/Type.h" #include "zeek/Type.h"
#include "zeek/Var.h" #include "zeek/Var.h"
namespace zeek namespace zeek {
{
std::list<std::pair<std::string, File*>> File::open_files; std::list<std::pair<std::string, File*>> File::open_files;
// Maximizes the number of open file descriptors. // Maximizes the number of open file descriptors.
static void maximize_num_fds() static void maximize_num_fds() {
{
struct rlimit rl; struct rlimit rl;
if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 ) if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 )
reporter->FatalError("maximize_num_fds(): getrlimit failed"); reporter->FatalError("maximize_num_fds(): getrlimit failed");
if ( rl.rlim_max == RLIM_INFINITY ) if ( rl.rlim_max == RLIM_INFINITY ) {
{
// Don't try raising the current limit. // Don't try raising the current limit.
return; return;
} }
@ -56,8 +53,7 @@ static void maximize_num_fds()
reporter->FatalError("maximize_num_fds(): setrlimit failed"); reporter->FatalError("maximize_num_fds(): setrlimit failed");
} }
File::File(FILE* arg_f) File::File(FILE* arg_f) {
{
Init(); Init();
f = arg_f; f = arg_f;
name = access = nullptr; name = access = nullptr;
@ -65,8 +61,7 @@ File::File(FILE* arg_f)
is_open = (f != nullptr); is_open = (f != nullptr);
} }
File::File(FILE* arg_f, const char* arg_name, const char* arg_access) File::File(FILE* arg_f, const char* arg_name, const char* arg_access) {
{
Init(); Init();
f = arg_f; f = arg_f;
name = util::copy_string(arg_name); name = util::copy_string(arg_name);
@ -75,8 +70,7 @@ File::File(FILE* arg_f, const char* arg_name, const char* arg_access)
is_open = (f != nullptr); is_open = (f != nullptr);
} }
File::File(const char* arg_name, const char* arg_access) File::File(const char* arg_name, const char* arg_access) {
{
Init(); Init();
f = nullptr; f = nullptr;
name = util::copy_string(arg_name); name = util::copy_string(arg_name);
@ -93,15 +87,13 @@ File::File(const char* arg_name, const char* arg_access)
if ( f ) if ( f )
is_open = true; is_open = true;
else if ( ! Open() ) else if ( ! Open() ) {
{
reporter->Error("cannot open %s: %s", name, strerror(errno)); reporter->Error("cannot open %s: %s", name, strerror(errno));
is_open = false; is_open = false;
} }
} }
const char* File::Name() const const char* File::Name() const {
{
if ( name ) if ( name )
return name; return name;
@ -117,13 +109,11 @@ const char* File::Name() const
return nullptr; return nullptr;
} }
bool File::Open(FILE* file, const char* mode) bool File::Open(FILE* file, const char* mode) {
{
static bool fds_maximized = false; static bool fds_maximized = false;
open_time = run_state::network_time ? run_state::network_time : util::current_time(); open_time = run_state::network_time ? run_state::network_time : util::current_time();
if ( ! fds_maximized ) if ( ! fds_maximized ) {
{
// Haven't initialized yet. // Haven't initialized yet.
maximize_num_fds(); maximize_num_fds();
fds_maximized = true; fds_maximized = true;
@ -131,8 +121,7 @@ bool File::Open(FILE* file, const char* mode)
f = file; f = file;
if ( ! f ) if ( ! f ) {
{
if ( ! mode ) if ( ! mode )
f = fopen(name, access); f = fopen(name, access);
else else
@ -141,8 +130,7 @@ bool File::Open(FILE* file, const char* mode)
SetBuf(buffered); SetBuf(buffered);
if ( ! f ) if ( ! f ) {
{
is_open = false; is_open = false;
return false; return false;
} }
@ -155,8 +143,7 @@ bool File::Open(FILE* file, const char* mode)
return true; return true;
} }
File::~File() File::~File() {
{
Close(); Close();
Unref(attrs); Unref(attrs);
@ -168,8 +155,7 @@ File::~File()
#endif #endif
} }
void File::Init() void File::Init() {
{
open_time = 0; open_time = 0;
is_open = false; is_open = false;
attrs = nullptr; attrs = nullptr;
@ -181,13 +167,9 @@ void File::Init()
#endif #endif
} }
FILE* File::FileHandle() FILE* File::FileHandle() { return f; }
{
return f;
}
FILE* File::Seek(long new_position) FILE* File::Seek(long new_position) {
{
if ( ! FileHandle() ) if ( ! FileHandle() )
return nullptr; return nullptr;
@ -197,8 +179,7 @@ FILE* File::Seek(long new_position)
return f; return f;
} }
void File::SetBuf(bool arg_buffered) void File::SetBuf(bool arg_buffered) {
{
if ( ! f ) if ( ! f )
return; return;
@ -208,8 +189,7 @@ void File::SetBuf(bool arg_buffered)
buffered = arg_buffered; buffered = arg_buffered;
} }
bool File::Close() bool File::Close() {
{
if ( ! is_open ) if ( ! is_open )
return true; return true;
@ -230,24 +210,19 @@ bool File::Close()
return true; return true;
} }
void File::Unlink() void File::Unlink() {
{ for ( auto it = open_files.begin(); it != open_files.end(); ++it ) {
for ( auto it = open_files.begin(); it != open_files.end(); ++it ) if ( (*it).second == this ) {
{
if ( (*it).second == this )
{
open_files.erase(it); open_files.erase(it);
return; return;
} }
} }
} }
void File::Describe(ODesc* d) const void File::Describe(ODesc* d) const {
{
d->AddSP("file"); d->AddSP("file");
if ( name ) if ( name ) {
{
d->Add("\""); d->Add("\"");
d->Add(name); d->Add(name);
d->AddSP("\""); d->AddSP("\"");
@ -260,8 +235,7 @@ void File::Describe(ODesc* d) const
d->Add("(no type)"); d->Add("(no type)");
} }
void File::SetAttrs(detail::Attributes* arg_attrs) void File::SetAttrs(detail::Attributes* arg_attrs) {
{
if ( ! arg_attrs ) if ( ! arg_attrs )
return; return;
@ -272,8 +246,7 @@ void File::SetAttrs(detail::Attributes* arg_attrs)
EnableRawOutput(); EnableRawOutput();
} }
RecordVal* File::Rotate() RecordVal* File::Rotate() {
{
if ( ! is_open ) if ( ! is_open )
return nullptr; return nullptr;
@ -285,8 +258,7 @@ RecordVal* File::Rotate()
auto* info = new RecordVal(rotate_info); auto* info = new RecordVal(rotate_info);
FILE* newf = util::detail::rotate_file(name, info); FILE* newf = util::detail::rotate_file(name, info);
if ( ! newf ) if ( ! newf ) {
{
Unref(info); Unref(info);
return nullptr; return nullptr;
} }
@ -302,18 +274,15 @@ RecordVal* File::Rotate()
return info; return info;
} }
void File::CloseOpenFiles() void File::CloseOpenFiles() {
{
auto it = open_files.begin(); auto it = open_files.begin();
while ( it != open_files.end() ) while ( it != open_files.end() ) {
{
auto el = it++; auto el = it++;
(*el).second->Close(); (*el).second->Close();
} }
} }
bool File::Write(const char* data, int len) bool File::Write(const char* data, int len) {
{
if ( ! is_open ) if ( ! is_open )
return false; return false;
@ -326,8 +295,7 @@ bool File::Write(const char* data, int len)
return true; return true;
} }
void File::RaiseOpenEvent() void File::RaiseOpenEvent() {
{
if ( ! ::file_opened ) if ( ! ::file_opened )
return; return;
@ -336,12 +304,10 @@ void File::RaiseOpenEvent()
event_mgr.Dispatch(event, true); event_mgr.Dispatch(event, true);
} }
double File::Size() double File::Size() {
{
fflush(f); fflush(f);
struct stat s; struct stat s;
if ( fstat(fileno(f), &s) < 0 ) if ( fstat(fileno(f), &s) < 0 ) {
{
reporter->Error("can't stat fd for %s: %s", name, strerror(errno)); reporter->Error("can't stat fd for %s: %s", name, strerror(errno));
return 0; return 0;
} }
@ -349,8 +315,7 @@ double File::Size()
return s.st_size; return s.st_size;
} }
FilePtr File::Get(const char* name) FilePtr File::Get(const char* name) {
{
for ( const auto& el : open_files ) for ( const auto& el : open_files )
if ( el.first == name ) if ( el.first == name )
return {NewRef{}, el.second}; return {NewRef{}, el.second};

View file

@ -10,18 +10,16 @@
#include "zeek/Val.h" #include "zeek/Val.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
namespace detail namespace detail {
{
class PrintStmt; class PrintStmt;
class Attributes; class Attributes;
extern void do_print_stmt(const std::vector<ValPtr>& vals); extern void do_print_stmt(const std::vector<ValPtr>& vals);
} // namespace detail; } // namespace detail
class RecordVal; class RecordVal;
class Type; class Type;
@ -30,8 +28,7 @@ using TypePtr = IntrusivePtr<Type>;
class File; class File;
using FilePtr = IntrusivePtr<File>; using FilePtr = IntrusivePtr<File>;
class File final : public Obj class File final : public Obj {
{
public: public:
explicit File(FILE* arg_f); explicit File(FILE* arg_f);
File(FILE* arg_f, const char* filename, const char* access); File(FILE* arg_f, const char* filename, const char* access);

View file

@ -13,12 +13,10 @@
#include <winsock2.h> #include <winsock2.h>
#define fatalError(...) \ #define fatalError(...) \
do \ do { \
{ \
if ( reporter ) \ if ( reporter ) \
reporter->FatalError(__VA_ARGS__); \ reporter->FatalError(__VA_ARGS__); \
else \ else { \
{ \
fprintf(stderr, __VA_ARGS__); \ fprintf(stderr, __VA_ARGS__); \
fprintf(stderr, "\n"); \ fprintf(stderr, "\n"); \
_exit(1); \ _exit(1); \
@ -27,13 +25,11 @@
#endif #endif
namespace zeek::detail namespace zeek::detail {
{
Flare::Flare() Flare::Flare()
#ifndef _MSC_VER #ifndef _MSC_VER
: pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) : pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) {
{
} }
#else #else
{ {
@ -41,12 +37,10 @@ Flare::Flare()
if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 ) if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 )
fatalError("WSAStartup failure: %d", WSAGetLastError()); fatalError("WSAStartup failure: %d", WSAGetLastError());
recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
if ( recvfd == (int)INVALID_SOCKET ) if ( recvfd == (int)INVALID_SOCKET )
fatalError("WSASocket failure: %d", WSAGetLastError()); fatalError("WSASocket failure: %d", WSAGetLastError());
sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT);
if ( sendfd == (int)INVALID_SOCKET ) if ( sendfd == (int)INVALID_SOCKET )
fatalError("WSASocket failure: %d", WSAGetLastError()); fatalError("WSASocket failure: %d", WSAGetLastError());
@ -64,8 +58,7 @@ Flare::Flare()
} }
#endif #endif
[[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) [[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) {
{
if ( signal_safe ) if ( signal_safe )
abort(); abort();
@ -74,19 +67,16 @@ Flare::Flare()
if ( reporter ) if ( reporter )
reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf); reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf);
else else {
{
fprintf(stderr, "unexpected pipe %s failure: %s", which, buf); fprintf(stderr, "unexpected pipe %s failure: %s", which, buf);
abort(); abort();
} }
} }
void Flare::Fire(bool signal_safe) void Flare::Fire(bool signal_safe) {
{
char tmp = 0; char tmp = 0;
for ( ;; ) for ( ;; ) {
{
#ifndef _MSC_VER #ifndef _MSC_VER
int n = write(pipe.WriteFD(), &tmp, 1); int n = write(pipe.WriteFD(), &tmp, 1);
@ -97,8 +87,7 @@ void Flare::Fire(bool signal_safe)
// Success -- wrote a byte to pipe. // Success -- wrote a byte to pipe.
break; break;
if ( n < 0 ) if ( n < 0 ) {
{
#ifdef _MSC_VER #ifdef _MSC_VER
errno = WSAGetLastError(); errno = WSAGetLastError();
bad_pipe_op("send", signal_safe); bad_pipe_op("send", signal_safe);
@ -118,20 +107,17 @@ void Flare::Fire(bool signal_safe)
} }
} }
int Flare::Extinguish(bool signal_safe) int Flare::Extinguish(bool signal_safe) {
{
int rval = 0; int rval = 0;
char tmp[256]; char tmp[256];
for ( ;; ) for ( ;; ) {
{
#ifndef _MSC_VER #ifndef _MSC_VER
int n = read(pipe.ReadFD(), &tmp, sizeof(tmp)); int n = read(pipe.ReadFD(), &tmp, sizeof(tmp));
#else #else
int n = recv(recvfd, tmp, sizeof(tmp), 0); int n = recv(recvfd, tmp, sizeof(tmp), 0);
#endif #endif
if ( n >= 0 ) if ( n >= 0 ) {
{
rval += n; rval += n;
// Pipe may not be empty yet: try again. // Pipe may not be empty yet: try again.
continue; continue;

View file

@ -6,11 +6,9 @@
#include "Pipe.h" #include "Pipe.h"
#endif #endif
namespace zeek::detail namespace zeek::detail {
{
class Flare class Flare {
{
public: public:
/** /**
* Create a flare object that can be used to signal a "ready" status via * Create a flare object that can be used to signal a "ready" status via

View file

@ -14,40 +14,34 @@
constexpr uint32_t MIN_ACCEPTABLE_FRAG_SIZE = 64; constexpr uint32_t MIN_ACCEPTABLE_FRAG_SIZE = 64;
constexpr uint32_t MAX_ACCEPTABLE_FRAG_SIZE = 64000; constexpr uint32_t MAX_ACCEPTABLE_FRAG_SIZE = 64000;
namespace zeek::detail namespace zeek::detail {
{
FragTimer::~FragTimer() FragTimer::~FragTimer() {
{
if ( f ) if ( f )
f->ClearTimer(); f->ClearTimer();
} }
void FragTimer::Dispatch(double t, bool /* is_expire */) void FragTimer::Dispatch(double t, bool /* is_expire */) {
{
if ( f ) if ( f )
f->Expire(t); f->Expire(t);
else else
reporter->InternalWarning("fragment timer dispatched w/o reassembler"); reporter->InternalWarning("fragment timer dispatched w/o reassembler");
} }
FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<IP_Hdr>& ip, FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt,
const u_char* pkt, const FragReassemblerKey& k, double t) const FragReassemblerKey& k, double t)
: Reassembler(0, REASSEM_FRAG) : Reassembler(0, REASSEM_FRAG) {
{
s = arg_s; s = arg_s;
key = k; key = k;
const struct ip* ip4 = ip->IP4_Hdr(); const struct ip* ip4 = ip->IP4_Hdr();
if ( ip4 ) if ( ip4 ) {
{
proto_hdr_len = ip->HdrLen(); proto_hdr_len = ip->HdrLen();
proto_hdr = new u_char[64]; // max IP header + slop proto_hdr = new u_char[64]; // max IP header + slop
// Don't do a structure copy - need to pick up options, too. // Don't do a structure copy - need to pick up options, too.
memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len); memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len);
} }
else else {
{
proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header
proto_hdr = new u_char[proto_hdr_len]; proto_hdr = new u_char[proto_hdr_len];
memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len); memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len);
@ -57,8 +51,7 @@ FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<
frag_size = 0; // flag meaning "not known" frag_size = 0; // flag meaning "not known"
next_proto = ip->NextProto(); next_proto = ip->NextProto();
if ( frag_timeout != 0.0 ) if ( frag_timeout != 0.0 ) {
{
expire_timer = new FragTimer(this, t + frag_timeout); expire_timer = new FragTimer(this, t + frag_timeout);
timer_mgr->Add(expire_timer); timer_mgr->Add(expire_timer);
} }
@ -68,28 +61,23 @@ FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr<
AddFragment(t, ip, pkt); AddFragment(t, ip, pkt);
} }
FragReassembler::~FragReassembler() FragReassembler::~FragReassembler() {
{
DeleteTimer(); DeleteTimer();
delete[] proto_hdr; delete[] proto_hdr;
} }
void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) {
{
const struct ip* ip4 = ip->IP4_Hdr(); const struct ip* ip4 = ip->IP4_Hdr();
if ( ip4 ) if ( ip4 ) {
{ if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p || ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl )
if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p ||
ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl )
// || ip4->ip_tos != proto_hdr->ip_tos // || ip4->ip_tos != proto_hdr->ip_tos
// don't check TOS, there's at least one stack that actually // don't check TOS, there's at least one stack that actually
// uses different values, and it's hard to see an associated // uses different values, and it's hard to see an associated
// attack. // attack.
s->Weird("fragment_protocol_inconsistency", ip.get()); s->Weird("fragment_protocol_inconsistency", ip.get());
} }
else else {
{
if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len ) if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len )
s->Weird("fragment_protocol_inconsistency", ip.get()); s->Weird("fragment_protocol_inconsistency", ip.get());
// TODO: more detailed unfrag header consistency checks? // TODO: more detailed unfrag header consistency checks?
@ -103,8 +91,7 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
uint32_t len = ip->TotalLen(); uint32_t len = ip->TotalLen();
uint16_t hdr_len = ip->HdrLen(); uint16_t hdr_len = ip->HdrLen();
if ( len < hdr_len ) if ( len < hdr_len ) {
{
s->Weird("fragment_protocol_inconsistency", ip.get()); s->Weird("fragment_protocol_inconsistency", ip.get());
return; return;
} }
@ -115,14 +102,12 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
// Make sure to use the first fragment header's next field. // Make sure to use the first fragment header's next field.
next_proto = ip->NextProto(); next_proto = ip->NextProto();
if ( ! ip->MF() ) if ( ! ip->MF() ) {
{
// Last fragment. // Last fragment.
if ( frag_size == 0 ) if ( frag_size == 0 )
frag_size = upper_seq; frag_size = upper_seq;
else if ( upper_seq != frag_size ) else if ( upper_seq != frag_size ) {
{
s->Weird("fragment_size_inconsistency", ip.get()); s->Weird("fragment_size_inconsistency", ip.get());
if ( upper_seq > frag_size ) if ( upper_seq > frag_size )
@ -136,8 +121,7 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE ) if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE )
s->Weird("excessively_large_fragment", ip.get()); s->Weird("excessively_large_fragment", ip.get());
if ( frag_size && upper_seq > frag_size ) if ( frag_size && upper_seq > frag_size ) {
{
// This can happen if we receive a fragment that's *not* // This can happen if we receive a fragment that's *not*
// the last fragment, but still imputes a size that's // the last fragment, but still imputes a size that's
// larger than the size we derived from a previously-seen // larger than the size we derived from a previously-seen
@ -157,39 +141,33 @@ void FragReassembler::AddFragment(double t, const std::shared_ptr<IP_Hdr>& ip, c
NewBlock(run_state::network_time, offset, len, pkt); NewBlock(run_state::network_time, offset, len, pkt);
} }
void FragReassembler::Weird(const char* name) const void FragReassembler::Weird(const char* name) const {
{
unsigned int version = ((const ip*)proto_hdr)->ip_v; unsigned int version = ((const ip*)proto_hdr)->ip_v;
if ( version == 4 ) if ( version == 4 ) {
{
IP_Hdr hdr((const ip*)proto_hdr, false); IP_Hdr hdr((const ip*)proto_hdr, false);
s->Weird(name, &hdr); s->Weird(name, &hdr);
} }
else if ( version == 6 ) else if ( version == 6 ) {
{
IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len); IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len);
s->Weird(name, &hdr); s->Weird(name, &hdr);
} }
else else {
{
reporter->InternalWarning("Unexpected IP version in FragReassembler"); reporter->InternalWarning("Unexpected IP version in FragReassembler");
reporter->Weird(name); reporter->Weird(name);
} }
} }
void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n) void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n) {
{
if ( memcmp((const void*)b1, (const void*)b2, n) ) if ( memcmp((const void*)b1, (const void*)b2, n) )
Weird("fragment_inconsistency"); Weird("fragment_inconsistency");
else else
Weird("fragment_overlap"); Weird("fragment_overlap");
} }
void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) {
{
auto it = block_list.Begin(); auto it = block_list.Begin();
if ( it->second.seq > 0 || ! frag_size ) if ( it->second.seq > 0 || ! frag_size )
@ -199,8 +177,7 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
auto next = std::next(it); auto next = std::next(it);
// We might have it all - look for contiguous all the way. // We might have it all - look for contiguous all the way.
while ( next != block_list.End() ) while ( next != block_list.End() ) {
{
if ( it->second.upper != next->second.seq ) if ( it->second.upper != next->second.seq )
break; break;
@ -210,11 +187,9 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
const auto& last = block_list.LastBlock(); const auto& last = block_list.LastBlock();
if ( next != block_list.End() ) if ( next != block_list.End() ) {
{
// We have a hole. // We have a hole.
if ( it->second.upper >= frag_size ) if ( it->second.upper >= frag_size ) {
{
// We're stuck. The point where we stopped is // We're stuck. The point where we stopped is
// contiguous up through the expected end of // contiguous up through the expected end of
// the fragment, but there's more stuff still // the fragment, but there's more stuff still
@ -232,8 +207,7 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
return; return;
} }
else if ( last.upper > frag_size ) else if ( last.upper > frag_size ) {
{
Weird("fragment_size_inconsistency"); Weird("fragment_size_inconsistency");
frag_size = last.upper; frag_size = last.upper;
} }
@ -259,12 +233,10 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
pkt += proto_hdr_len; pkt += proto_hdr_len;
for ( it = block_list.Begin(); it != block_list.End(); ++it ) for ( it = block_list.Begin(); it != block_list.End(); ++it ) {
{
const auto& b = it->second; const auto& b = it->second;
if ( it != block_list.Begin() ) if ( it != block_list.Begin() ) {
{
const auto& prev = std::prev(it)->second; const auto& prev = std::prev(it)->second;
// If we're above a hole, stop. This can happen because // If we're above a hole, stop. This can happen because
@ -274,8 +246,7 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
break; break;
} }
if ( b.upper > n ) if ( b.upper > n ) {
{
reporter->InternalWarning("bad fragment reassembly"); reporter->InternalWarning("bad fragment reassembly");
DeleteTimer(); DeleteTimer();
Expire(run_state::network_time); Expire(run_state::network_time);
@ -290,16 +261,14 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
unsigned int version = ((const struct ip*)pkt_start)->ip_v; unsigned int version = ((const struct ip*)pkt_start)->ip_v;
if ( version == 4 ) if ( version == 4 ) {
{
struct ip* reassem4 = (struct ip*)pkt_start; struct ip* reassem4 = (struct ip*)pkt_start;
reassem4->ip_len = htons(frag_size + proto_hdr_len); reassem4->ip_len = htons(frag_size + proto_hdr_len);
reassembled_pkt = std::make_shared<IP_Hdr>(reassem4, true, true); reassembled_pkt = std::make_shared<IP_Hdr>(reassem4, true, true);
DeleteTimer(); DeleteTimer();
} }
else if ( version == 6 ) else if ( version == 6 ) {
{
struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start; struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start;
reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40); reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40);
const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n); const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n);
@ -307,15 +276,13 @@ void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */)
DeleteTimer(); DeleteTimer();
} }
else else {
{
reporter->InternalWarning("bad IP version in fragment reassembly: %d", version); reporter->InternalWarning("bad IP version in fragment reassembly: %d", version);
delete[] pkt_start; delete[] pkt_start;
} }
} }
void FragReassembler::Expire(double t) void FragReassembler::Expire(double t) {
{
block_list.Clear(); block_list.Clear();
expire_timer->ClearReassembler(); expire_timer->ClearReassembler();
expire_timer = nullptr; // timer manager will delete it expire_timer = nullptr; // timer manager will delete it
@ -323,24 +290,17 @@ void FragReassembler::Expire(double t)
fragment_mgr->Remove(this); fragment_mgr->Remove(this);
} }
void FragReassembler::DeleteTimer() void FragReassembler::DeleteTimer() {
{ if ( expire_timer ) {
if ( expire_timer )
{
expire_timer->ClearReassembler(); expire_timer->ClearReassembler();
timer_mgr->Cancel(expire_timer); timer_mgr->Cancel(expire_timer);
expire_timer = nullptr; // timer manager will delete it expire_timer = nullptr; // timer manager will delete it
} }
} }
FragmentManager::~FragmentManager() FragmentManager::~FragmentManager() { Clear(); }
{
Clear();
}
FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt) {
const u_char* pkt)
{
uint32_t frag_id = ip->ID(); uint32_t frag_id = ip->ID();
FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id); FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id);
@ -349,8 +309,7 @@ FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<I
if ( it != fragments.end() ) if ( it != fragments.end() )
f = it->second; f = it->second;
if ( ! f ) if ( ! f ) {
{
f = new FragReassembler(session_mgr, ip, pkt, key, t); f = new FragReassembler(session_mgr, ip, pkt, key, t);
fragments[key] = f; fragments[key] = f;
if ( fragments.size() > max_fragments ) if ( fragments.size() > max_fragments )
@ -362,16 +321,14 @@ FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr<I
return f; return f;
} }
void FragmentManager::Clear() void FragmentManager::Clear() {
{
for ( const auto& entry : fragments ) for ( const auto& entry : fragments )
Unref(entry.second); Unref(entry.second);
fragments.clear(); fragments.clear();
} }
void FragmentManager::Remove(detail::FragReassembler* f) void FragmentManager::Remove(detail::FragReassembler* f) {
{
if ( ! f ) if ( ! f )
return; return;

View file

@ -10,26 +10,22 @@
#include "zeek/Timer.h" #include "zeek/Timer.h"
#include "zeek/util.h" // for zeek_uint_t #include "zeek/util.h" // for zeek_uint_t
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
namespace session namespace session {
{
class Manager; class Manager;
} }
namespace detail namespace detail {
{
class FragReassembler; class FragReassembler;
class FragTimer; class FragTimer;
using FragReassemblerKey = std::tuple<IPAddr, IPAddr, zeek_uint_t>; using FragReassemblerKey = std::tuple<IPAddr, IPAddr, zeek_uint_t>;
class FragReassembler : public Reassembler class FragReassembler : public Reassembler {
{
public: public:
FragReassembler(session::Manager* s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt, FragReassembler(session::Manager* s, const std::shared_ptr<IP_Hdr>& ip, const u_char* pkt,
const FragReassemblerKey& k, double t); const FragReassemblerKey& k, double t);
@ -60,8 +56,7 @@ protected:
FragTimer* expire_timer; FragTimer* expire_timer;
}; };
class FragTimer final : public Timer class FragTimer final : public Timer {
{
public: public:
FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; } FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; }
~FragTimer() override; ~FragTimer() override;
@ -75,8 +70,7 @@ protected:
FragReassembler* f; FragReassembler* f;
}; };
class FragmentManager class FragmentManager {
{
public: public:
FragmentManager() = default; FragmentManager() = default;
~FragmentManager(); ~FragmentManager();
@ -96,8 +90,7 @@ private:
extern FragmentManager* fragment_mgr; extern FragmentManager* fragment_mgr;
class FragReassemblerTracker class FragReassemblerTracker {
{
public: public:
FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) {} FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) {}

View file

@ -13,11 +13,9 @@
std::vector<zeek::detail::Frame*> g_frame_stack; std::vector<zeek::detail::Frame*> g_frame_stack;
namespace zeek::detail namespace zeek::detail {
{
Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) {
{
size = arg_size; size = arg_size;
frame = std::make_unique<Element[]>(size); frame = std::make_unique<Element[]>(size);
function = func; function = func;
@ -40,20 +38,16 @@ Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args)
current_offset = 0; current_offset = 0;
} }
void Frame::SetElement(int n, ValPtr v) void Frame::SetElement(int n, ValPtr v) {
{
n += current_offset; n += current_offset;
ASSERT(n >= 0 && n < size); ASSERT(n >= 0 && n < size);
frame[n] = std::move(v); frame[n] = std::move(v);
} }
void Frame::SetElement(const ID* id, ValPtr v) void Frame::SetElement(const ID* id, ValPtr v) {
{ if ( captures ) {
if ( captures )
{
auto cap_off = captures_offset_map->find(id->Name()); auto cap_off = captures_offset_map->find(id->Name());
if ( cap_off != captures_offset_map->end() ) if ( cap_off != captures_offset_map->end() ) {
{
captures->SetElement(cap_off->second, std::move(v)); captures->SetElement(cap_off->second, std::move(v));
return; return;
} }
@ -62,10 +56,8 @@ void Frame::SetElement(const ID* id, ValPtr v)
SetElement(id->Offset(), std::move(v)); SetElement(id->Offset(), std::move(v));
} }
const ValPtr& Frame::GetElementByID(const ID* id) const const ValPtr& Frame::GetElementByID(const ID* id) const {
{ if ( captures ) {
if ( captures )
{
auto cap_off = captures_offset_map->find(id->Name()); auto cap_off = captures_offset_map->find(id->Name());
if ( cap_off != captures_offset_map->end() ) if ( cap_off != captures_offset_map->end() )
return captures->GetElement(cap_off->second); return captures->GetElement(cap_off->second);
@ -74,23 +66,19 @@ const ValPtr& Frame::GetElementByID(const ID* id) const
return frame[id->Offset() + current_offset]; return frame[id->Offset() + current_offset];
} }
void Frame::Reset(int startIdx) void Frame::Reset(int startIdx) {
{
for ( int i = startIdx + current_offset; i < size; ++i ) for ( int i = startIdx + current_offset; i < size; ++i )
frame[i] = nullptr; frame[i] = nullptr;
} }
void Frame::Describe(ODesc* d) const void Frame::Describe(ODesc* d) const {
{
if ( ! d->IsBinary() ) if ( ! d->IsBinary() )
d->AddSP("frame"); d->AddSP("frame");
if ( ! d->IsReadable() ) if ( ! d->IsReadable() ) {
{
d->Add(size); d->Add(size);
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{
d->Add(frame[i] != nullptr); d->Add(frame[i] != nullptr);
d->SP(); d->SP();
} }
@ -103,8 +91,7 @@ void Frame::Describe(ODesc* d) const
d->Add("<nil>"); d->Add("<nil>");
} }
Frame* Frame::Clone() const Frame* Frame::Clone() const {
{
Frame* other = new Frame(size, function, func_args); Frame* other = new Frame(size, function, func_args);
other->call = call; other->call = call;
@ -121,8 +108,7 @@ Frame* Frame::Clone() const
return other; return other;
} }
Frame* Frame::CloneForTrigger() const Frame* Frame::CloneForTrigger() const {
{
Frame* other = new Frame(0, function, func_args); Frame* other = new Frame(0, function, func_args);
other->call = call; other->call = call;
@ -132,20 +118,17 @@ Frame* Frame::CloneForTrigger() const
return other; return other;
} }
static bool val_is_func(const ValPtr& v, ScriptFunc* func) static bool val_is_func(const ValPtr& v, ScriptFunc* func) {
{
if ( v->GetType()->Tag() != TYPE_FUNC ) if ( v->GetType()->Tag() != TYPE_FUNC )
return false; return false;
return v->AsFunc() == func; return v->AsFunc() == func;
} }
broker::expected<broker::data> Frame::Serialize() broker::expected<broker::data> Frame::Serialize() {
{
broker::vector body; broker::vector body;
for ( int i = 0; i < size; ++i ) for ( int i = 0; i < size; ++i ) {
{
const auto& val = frame[i]; const auto& val = frame[i];
auto expected = Broker::detail::val_to_data(val.get()); auto expected = Broker::detail::val_to_data(val.get());
if ( ! expected ) if ( ! expected )
@ -162,8 +145,7 @@ broker::expected<broker::data> Frame::Serialize()
return {std::move(rval)}; return {std::move(rval)};
} }
std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data) std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data) {
{
if ( data.size() == 0 ) if ( data.size() == 0 )
return std::make_pair(true, nullptr); return std::make_pair(true, nullptr);
@ -176,8 +158,7 @@ std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data)
int frame_size = body.size(); int frame_size = body.size();
auto rf = make_intrusive<Frame>(frame_size, nullptr, nullptr); auto rf = make_intrusive<Frame>(frame_size, nullptr, nullptr);
for ( int i = 0; i < frame_size; ++i ) for ( int i = 0; i < frame_size; ++i ) {
{
auto has_vec = broker::get_if<broker::vector>(body[i]); auto has_vec = broker::get_if<broker::vector>(body[i]);
if ( ! has_vec ) if ( ! has_vec )
continue; continue;
@ -203,21 +184,14 @@ std::pair<bool, FramePtr> Frame::Unserialize(const broker::vector& data)
return std::make_pair(true, std::move(rf)); return std::make_pair(true, std::move(rf));
} }
const detail::Location* Frame::GetCallLocation() const const detail::Location* Frame::GetCallLocation() const {
{
// This is currently trivial, but we keep it as an explicit // This is currently trivial, but we keep it as an explicit
// method because it can provide flexibility for compiled code. // method because it can provide flexibility for compiled code.
return call->GetLocationInfo(); return call->GetLocationInfo();
} }
void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) { trigger = std::move(arg_trigger); }
{
trigger = std::move(arg_trigger);
}
void Frame::ClearTrigger() void Frame::ClearTrigger() { trigger = nullptr; }
{
trigger = nullptr;
}
} } // namespace zeek::detail

View file

@ -17,31 +17,27 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" // for typedef val_list #include "zeek/ZeekList.h" // for typedef val_list
namespace zeek namespace zeek {
{
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
namespace detail namespace detail {
{
class CallExpr; class CallExpr;
class ScriptFunc; class ScriptFunc;
using IDPtr = IntrusivePtr<ID>; using IDPtr = IntrusivePtr<ID>;
namespace trigger namespace trigger {
{
class Trigger; class Trigger;
using TriggerPtr = IntrusivePtr<Trigger>; using TriggerPtr = IntrusivePtr<Trigger>;
} } // namespace trigger
class Frame; class Frame;
using FramePtr = IntrusivePtr<Frame>; using FramePtr = IntrusivePtr<Frame>;
class Frame : public Obj class Frame : public Obj {
{
public: public:
/** /**
* Constructs a new frame belonging to *func* with *fn_args* * Constructs a new frame belonging to *func* with *fn_args*
@ -64,8 +60,7 @@ public:
* @param n the index to get. * @param n the index to get.
* @return the value at index *n* of the underlying array. * @return the value at index *n* of the underlying array.
*/ */
const ValPtr& GetElement(int n) const const ValPtr& GetElement(int n) const {
{
// Note: technically this may want to adjust by current_offset, but // Note: technically this may want to adjust by current_offset, but
// in practice, this method is never called from anywhere other than // in practice, this method is never called from anywhere other than
// function call invocation, where current_offset should be zero. // function call invocation, where current_offset should be zero.
@ -193,8 +188,7 @@ public:
void ClearTrigger(); void ClearTrigger();
trigger::Trigger* GetTrigger() const { return trigger.get(); } trigger::Trigger* GetTrigger() const { return trigger.get(); }
void SetCall(const CallExpr* arg_call) void SetCall(const CallExpr* arg_call) {
{
call = arg_call; call = arg_call;
SetTriggerAssoc((void*)call); SetTriggerAssoc((void*)call);
} }

File diff suppressed because it is too large Load diff

View file

@ -18,21 +18,19 @@
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
#include "zeek/ZeekList.h" #include "zeek/ZeekList.h"
namespace broker namespace broker {
{
class data; class data;
using vector = std::vector<data>; using vector = std::vector<data>;
template <class> class expected; template<class>
} class expected;
} // namespace broker
namespace zeek namespace zeek {
{
class Val; class Val;
class FuncType; class FuncType;
namespace detail namespace detail {
{
class Scope; class Scope;
class Stmt; class Stmt;
@ -54,24 +52,18 @@ using EventGroupPtr = std::shared_ptr<EventGroup>;
class Func; class Func;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
class Func : public Obj class Func : public Obj {
{
public: public:
static inline const FuncPtr nil; static inline const FuncPtr nil;
enum Kind enum Kind { SCRIPT_FUNC, BUILTIN_FUNC };
{
SCRIPT_FUNC,
BUILTIN_FUNC
};
explicit Func(Kind arg_kind) : kind(arg_kind) {} explicit Func(Kind arg_kind) : kind(arg_kind) {}
virtual bool IsPure() const = 0; virtual bool IsPure() const = 0;
FunctionFlavor Flavor() const { return GetType()->Flavor(); } FunctionFlavor Flavor() const { return GetType()->Flavor(); }
struct Body struct Body {
{
detail::StmtPtr stmts; detail::StmtPtr stmts;
int priority; int priority;
std::set<EventGroupPtr> groups; std::set<EventGroupPtr> groups;
@ -79,10 +71,7 @@ public:
// The disabled field is updated from EventGroup instances. // The disabled field is updated from EventGroup instances.
bool disabled = false; bool disabled = false;
bool operator<(const Body& other) const bool operator<(const Body& other) const { return priority > other.priority; } // reverse sort
{
return priority > other.priority;
} // reverse sort
}; };
const std::vector<Body>& GetBodies() const { return bodies; } const std::vector<Body>& GetBodies() const { return bodies; }
@ -107,10 +96,8 @@ public:
* A version of Invoke() taking a variable number of individual arguments. * A version of Invoke() taking a variable number of individual arguments.
*/ */
template<class... Args> template<class... Args>
std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>, std::enable_if_t<std::is_convertible_v<std::tuple_element_t<0, std::tuple<Args...>>, ValPtr>, ValPtr> Invoke(
ValPtr> Args&&... args) const {
Invoke(Args&&... args) const
{
auto zargs = zeek::Args{std::forward<Args>(args)...}; auto zargs = zeek::Args{std::forward<Args>(args)...};
return Invoke(&zargs); return Invoke(&zargs);
} }
@ -121,11 +108,10 @@ public:
// as is a non-default second parameter to the first method, which // as is a non-default second parameter to the first method, which
// overrides the function body in "ingr". // overrides the function body in "ingr".
void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr); void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr);
virtual void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, virtual void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
size_t new_frame_size, int priority, int priority, const std::set<EventGroupPtr>& groups);
const std::set<EventGroupPtr>& groups); void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, int priority = 0);
size_t new_frame_size, int priority = 0);
void AddBody(detail::StmtPtr new_body, size_t new_frame_size); void AddBody(detail::StmtPtr new_body, size_t new_frame_size);
virtual void SetScope(detail::ScopePtr newscope); virtual void SetScope(detail::ScopePtr newscope);
@ -168,17 +154,14 @@ private:
bool has_enabled_bodies = true; bool has_enabled_bodies = true;
}; };
namespace detail namespace detail {
{
class ScriptFunc : public Func class ScriptFunc : public Func {
{
public: public:
ScriptFunc(const IDPtr& id); ScriptFunc(const IDPtr& id);
// For compiled scripts. // For compiled scripts.
ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies, ScriptFunc(std::string name, FuncTypePtr ft, std::vector<StmtPtr> bodies, std::vector<int> priorities);
std::vector<int> priorities);
~ScriptFunc() override; ~ScriptFunc() override;
@ -220,8 +203,7 @@ public:
* *
* @return internal vector of ZVal's kept for persisting captures * @return internal vector of ZVal's kept for persisting captures
*/ */
auto& GetCapturesVec() const auto& GetCapturesVec() const {
{
ASSERT(captures_vec); ASSERT(captures_vec);
return *captures_vec; return *captures_vec;
} }
@ -252,9 +234,8 @@ public:
using Func::AddBody; using Func::AddBody;
void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, void AddBody(detail::StmtPtr new_body, const std::vector<detail::IDPtr>& new_inits, size_t new_frame_size,
size_t new_frame_size, int priority, int priority, const std::set<EventGroupPtr>& groups) override;
const std::set<EventGroupPtr>& groups) override;
/** /**
* Replaces the given current instance of a function body with * Replaces the given current instance of a function body with
@ -332,8 +313,7 @@ private:
using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args); using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args);
class BuiltinFunc final : public Func class BuiltinFunc final : public Func {
{
public: public:
BuiltinFunc(built_in_func func, const char* name, bool is_pure); BuiltinFunc(built_in_func func, const char* name, bool is_pure);
~BuiltinFunc() override = default; ~BuiltinFunc() override = default;
@ -345,8 +325,7 @@ public:
void Describe(ODesc* d) const override; void Describe(ODesc* d) const override;
protected: protected:
BuiltinFunc() BuiltinFunc() {
{
func = nullptr; func = nullptr;
is_pure = 0; is_pure = 0;
} }
@ -357,16 +336,14 @@ protected:
extern bool check_built_in_call(BuiltinFunc* f, CallExpr* call); extern bool check_built_in_call(BuiltinFunc* f, CallExpr* call);
struct CallInfo struct CallInfo {
{
const CallExpr* call; const CallExpr* call;
const Func* func; const Func* func;
const zeek::Args& args; const zeek::Args& args;
}; };
// Class that collects all the specifics defining a Func. // Class that collects all the specifics defining a Func.
class FunctionIngredients class FunctionIngredients {
{
public: public:
// Gathers all of the information from a scope and a function body needed // Gathers all of the information from a scope and a function body needed
// to build a function. // to build a function.
@ -427,8 +404,7 @@ extern bool did_builtin_init;
extern std::vector<void (*)()> bif_initializers; extern std::vector<void (*)()> bif_initializers;
extern void init_primary_bifs(); extern void init_primary_bifs();
inline void run_bif_initializers() inline void run_bif_initializers() {
{
for ( const auto& bi : bif_initializers ) for ( const auto& bi : bif_initializers )
bi(); bi();

View file

@ -18,8 +18,7 @@
#include "const.bif.netvar_h" #include "const.bif.netvar_h"
namespace zeek::detail namespace zeek::detail {
{
alignas(32) uint64_t KeyedHash::shared_highwayhash_key[4]; alignas(32) uint64_t KeyedHash::shared_highwayhash_key[4];
alignas(32) uint64_t KeyedHash::cluster_highwayhash_key[4]; alignas(32) uint64_t KeyedHash::cluster_highwayhash_key[4];
@ -27,17 +26,12 @@ alignas(16) unsigned long long KeyedHash::shared_siphash_key[2];
// we use the following lines to not pull in the highwayhash headers in Hash.h - but to check the // we use the following lines to not pull in the highwayhash headers in Hash.h - but to check the
// types did not change underneath us. // types did not change underneath us.
static_assert(std::is_same_v<hash64_t, highwayhash::HHResult64>, static_assert(std::is_same_v<hash64_t, highwayhash::HHResult64>, "Highwayhash return values must match hash_x_t");
"Highwayhash return values must match hash_x_t"); static_assert(std::is_same_v<hash128_t, highwayhash::HHResult128>, "Highwayhash return values must match hash_x_t");
static_assert(std::is_same_v<hash128_t, highwayhash::HHResult128>, static_assert(std::is_same_v<hash256_t, highwayhash::HHResult256>, "Highwayhash return values must match hash_x_t");
"Highwayhash return values must match hash_x_t");
static_assert(std::is_same_v<hash256_t, highwayhash::HHResult256>,
"Highwayhash return values must match hash_x_t");
void KeyedHash::InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed_data) void KeyedHash::InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed_data) {
{ static_assert(std::is_same_v<decltype(KeyedHash::shared_siphash_key), highwayhash::SipHashState::Key>,
static_assert(
std::is_same_v<decltype(KeyedHash::shared_siphash_key), highwayhash::SipHashState::Key>,
"Highwayhash Key is not unsigned long long[2]"); "Highwayhash Key is not unsigned long long[2]");
static_assert(std::is_same_v<decltype(KeyedHash::shared_highwayhash_key), highwayhash::HHKey>, static_assert(std::is_same_v<decltype(KeyedHash::shared_highwayhash_key), highwayhash::HHKey>,
"Highwayhash HHKey is not uint64_t[4]"); "Highwayhash HHKey is not uint64_t[4]");
@ -59,126 +53,91 @@ void KeyedHash::InitializeSeeds(const std::array<uint32_t, SEED_INIT_SIZE>& seed
seeds_initialized = true; seeds_initialized = true;
} }
void KeyedHash::InitOptions() void KeyedHash::InitOptions() {
{
calculate_digest(Hash_SHA256, BifConst::digest_salt->Bytes(), BifConst::digest_salt->Len(), calculate_digest(Hash_SHA256, BifConst::digest_salt->Bytes(), BifConst::digest_salt->Len(),
reinterpret_cast<unsigned char*>(cluster_highwayhash_key)); reinterpret_cast<unsigned char*>(cluster_highwayhash_key));
} }
hash64_t KeyedHash::Hash64(const void* bytes, uint64_t size) hash64_t KeyedHash::Hash64(const void* bytes, uint64_t size) {
{
return highwayhash::SipHash(shared_siphash_key, static_cast<const char*>(bytes), size); return highwayhash::SipHash(shared_siphash_key, static_cast<const char*>(bytes), size);
} }
void KeyedHash::Hash128(const void* bytes, uint64_t size, hash128_t* result) void KeyedHash::Hash128(const void* bytes, uint64_t size, hash128_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(shared_highwayhash_key, static_cast<const char*>(bytes),
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( size, result);
shared_highwayhash_key, static_cast<const char*>(bytes), size, result);
} }
void KeyedHash::Hash256(const void* bytes, uint64_t size, hash256_t* result) void KeyedHash::Hash256(const void* bytes, uint64_t size, hash256_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(shared_highwayhash_key, static_cast<const char*>(bytes),
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( size, result);
shared_highwayhash_key, static_cast<const char*>(bytes), size, result);
} }
hash64_t KeyedHash::StaticHash64(const void* bytes, uint64_t size) hash64_t KeyedHash::StaticHash64(const void* bytes, uint64_t size) {
{
hash64_t result = 0; hash64_t result = 0;
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(cluster_highwayhash_key,
cluster_highwayhash_key, static_cast<const char*>(bytes), size, &result); static_cast<const char*>(bytes), size, &result);
return result; return result;
} }
void KeyedHash::StaticHash128(const void* bytes, uint64_t size, hash128_t* result) void KeyedHash::StaticHash128(const void* bytes, uint64_t size, hash128_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(cluster_highwayhash_key,
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( static_cast<const char*>(bytes), size, result);
cluster_highwayhash_key, static_cast<const char*>(bytes), size, result);
} }
void KeyedHash::StaticHash256(const void* bytes, uint64_t size, hash256_t* result) void KeyedHash::StaticHash256(const void* bytes, uint64_t size, hash256_t* result) {
{ highwayhash::InstructionSets::Run<highwayhash::HighwayHash>(cluster_highwayhash_key,
highwayhash::InstructionSets::Run<highwayhash::HighwayHash>( static_cast<const char*>(bytes), size, result);
cluster_highwayhash_key, static_cast<const char*>(bytes), size, result);
} }
void init_hash_function() void init_hash_function() {
{
// Make sure we have already called init_random_seed(). // Make sure we have already called init_random_seed().
if ( ! KeyedHash::IsInitialized() ) if ( ! KeyedHash::IsInitialized() )
reporter->InternalError("Zeek's hash functions aren't fully initialized"); reporter->InternalError("Zeek's hash functions aren't fully initialized");
} }
HashKey::HashKey(bool b) HashKey::HashKey(bool b) { Set(b); }
{
Set(b);
}
HashKey::HashKey(int i) HashKey::HashKey(int i) { Set(i); }
{
Set(i);
}
HashKey::HashKey(zeek_int_t bi) HashKey::HashKey(zeek_int_t bi) { Set(bi); }
{
Set(bi);
}
HashKey::HashKey(zeek_uint_t bu) HashKey::HashKey(zeek_uint_t bu) { Set(bu); }
{
Set(bu);
}
HashKey::HashKey(uint32_t u) HashKey::HashKey(uint32_t u) { Set(u); }
{
Set(u);
}
HashKey::HashKey(const uint32_t u[], size_t n) HashKey::HashKey(const uint32_t u[], size_t n) {
{
size = write_size = n * sizeof(u[0]); size = write_size = n * sizeof(u[0]);
key = (char*)u; key = (char*)u;
} }
HashKey::HashKey(double d) HashKey::HashKey(double d) { Set(d); }
{
Set(d);
}
HashKey::HashKey(const void* p) HashKey::HashKey(const void* p) { Set(p); }
{
Set(p);
}
HashKey::HashKey(const char* s) HashKey::HashKey(const char* s) {
{
size = write_size = strlen(s); // note - skip final \0 size = write_size = strlen(s); // note - skip final \0
key = (char*)s; key = (char*)s;
} }
HashKey::HashKey(const String* s) HashKey::HashKey(const String* s) {
{
size = write_size = s->Len(); size = write_size = s->Len();
key = (char*)s->Bytes(); key = (char*)s->Bytes();
} }
HashKey::HashKey(const void* bytes, size_t arg_size) HashKey::HashKey(const void* bytes, size_t arg_size) {
{
size = write_size = arg_size; size = write_size = arg_size;
key = CopyKey((char*)bytes, size); key = CopyKey((char*)bytes, size);
is_our_dynamic = true; is_our_dynamic = true;
} }
HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash) HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash) {
{
size = write_size = arg_size; size = write_size = arg_size;
hash = arg_hash; hash = arg_hash;
key = CopyKey((char*)arg_key, size); key = CopyKey((char*)arg_key, size);
is_our_dynamic = true; is_our_dynamic = true;
} }
HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash, bool /* dont_copy */) HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash, bool /* dont_copy */) {
{
size = write_size = arg_size; size = write_size = arg_size;
hash = arg_hash; hash = arg_hash;
key = (char*)arg_key; key = (char*)arg_key;
@ -186,8 +145,7 @@ HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash, bool /*
HashKey::HashKey(const HashKey& other) : HashKey(other.key, other.size, other.hash) {} HashKey::HashKey(const HashKey& other) : HashKey(other.key, other.size, other.hash) {}
HashKey::HashKey(HashKey&& other) noexcept HashKey::HashKey(HashKey&& other) noexcept {
{
hash = other.hash; hash = other.hash;
size = other.size; size = other.size;
write_size = other.write_size; write_size = other.write_size;
@ -201,19 +159,16 @@ HashKey::HashKey(HashKey&& other) noexcept
other.key = nullptr; other.key = nullptr;
} }
HashKey::~HashKey() HashKey::~HashKey() {
{
if ( is_our_dynamic ) if ( is_our_dynamic )
delete[] reinterpret_cast<char*>(key); delete[] reinterpret_cast<char*>(key);
} }
hash_t HashKey::Hash() const hash_t HashKey::Hash() const {
{
if ( hash == 0 ) if ( hash == 0 )
hash = HashBytes(key, size); hash = HashBytes(key, size);
#ifdef DEBUG #ifdef DEBUG
if ( zeek::detail::debug_logger.IsEnabled(DBG_HASHKEY) ) if ( zeek::detail::debug_logger.IsEnabled(DBG_HASHKEY) ) {
{
ODesc d; ODesc d;
Describe(&d); Describe(&d);
DBG_LOG(DBG_HASHKEY, "HashKey %p %s", this, d.Description()); DBG_LOG(DBG_HASHKEY, "HashKey %p %s", this, d.Description());
@ -222,10 +177,8 @@ hash_t HashKey::Hash() const
return hash; return hash;
} }
void* HashKey::TakeKey() void* HashKey::TakeKey() {
{ if ( is_our_dynamic ) {
if ( is_our_dynamic )
{
is_our_dynamic = false; is_our_dynamic = false;
return key; return key;
} }
@ -233,21 +186,17 @@ void* HashKey::TakeKey()
return CopyKey(key, size); return CopyKey(key, size);
} }
void HashKey::Describe(ODesc* d) const void HashKey::Describe(ODesc* d) const {
{
char buf[64]; char buf[64];
snprintf(buf, 16, "%0" PRIx64, hash); snprintf(buf, 16, "%0" PRIx64, hash);
d->Add(buf); d->Add(buf);
d->SP(); d->SP();
if ( size > 0 ) if ( size > 0 ) {
{
d->Add(IsAllocated() ? "(" : "["); d->Add(IsAllocated() ? "(" : "[");
for ( size_t i = 0; i < size; i++ ) for ( size_t i = 0; i < size; i++ ) {
{ if ( i > 0 ) {
if ( i > 0 )
{
d->SP(); d->SP();
// Extra spacing every 8 bytes, for readability. // Extra spacing every 8 bytes, for readability.
if ( i % 8 == 0 ) if ( i % 8 == 0 )
@ -255,8 +204,7 @@ void HashKey::Describe(ODesc* d) const
} }
// Don't display unwritten content, only say how much there is. // Don't display unwritten content, only say how much there is.
if ( i > write_size ) if ( i > write_size ) {
{
d->Add("<+"); d->Add("<+");
d->Add(static_cast<uint64_t>(size - write_size - 1)); d->Add(static_cast<uint64_t>(size - write_size - 1));
d->Add(" of "); d->Add(" of ");
@ -273,82 +221,68 @@ void HashKey::Describe(ODesc* d) const
} }
} }
char* HashKey::CopyKey(const char* k, size_t s) const char* HashKey::CopyKey(const char* k, size_t s) const {
{
char* k_copy = new char[s]; // s == 0 is okay, returns non-nil char* k_copy = new char[s]; // s == 0 is okay, returns non-nil
memcpy(k_copy, k, s); memcpy(k_copy, k, s);
return k_copy; return k_copy;
} }
hash_t HashKey::HashBytes(const void* bytes, size_t size) hash_t HashKey::HashBytes(const void* bytes, size_t size) { return KeyedHash::Hash64(bytes, size); }
{
return KeyedHash::Hash64(bytes, size);
}
void HashKey::Set(bool b) void HashKey::Set(bool b) {
{
key_u.b = b; key_u.b = b;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(b); size = write_size = sizeof(b);
} }
void HashKey::Set(int i) void HashKey::Set(int i) {
{
key_u.i = i; key_u.i = i;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(i); size = write_size = sizeof(i);
} }
void HashKey::Set(zeek_int_t bi) void HashKey::Set(zeek_int_t bi) {
{
key_u.bi = bi; key_u.bi = bi;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(bi); size = write_size = sizeof(bi);
} }
void HashKey::Set(zeek_uint_t bu) void HashKey::Set(zeek_uint_t bu) {
{
key_u.bi = zeek_int_t(bu); key_u.bi = zeek_int_t(bu);
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(bu); size = write_size = sizeof(bu);
} }
void HashKey::Set(uint32_t u) void HashKey::Set(uint32_t u) {
{
key_u.u32 = u; key_u.u32 = u;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(u); size = write_size = sizeof(u);
} }
void HashKey::Set(double d) void HashKey::Set(double d) {
{
key_u.d = d; key_u.d = d;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(d); size = write_size = sizeof(d);
} }
void HashKey::Set(const void* p) void HashKey::Set(const void* p) {
{
key_u.p = p; key_u.p = p;
key = reinterpret_cast<char*>(&key_u); key = reinterpret_cast<char*>(&key_u);
size = write_size = sizeof(p); size = write_size = sizeof(p);
} }
void HashKey::Reserve(const char* tag, size_t addl_size, size_t alignment) void HashKey::Reserve(const char* tag, size_t addl_size, size_t alignment) {
{
ASSERT(! IsAllocated()); ASSERT(! IsAllocated());
size_t s0 = size; size_t s0 = size;
size_t s1 = util::memory_size_align(size, alignment); size_t s1 = util::memory_size_align(size, alignment);
size = s1 + addl_size; size = s1 + addl_size;
DBG_LOG(DBG_HASHKEY, "HashKey %p reserving %lu/%lu: %lu -> %lu -> %lu [%s]", this, addl_size, DBG_LOG(DBG_HASHKEY, "HashKey %p reserving %lu/%lu: %lu -> %lu -> %lu [%s]", this, addl_size, alignment, s0, s1,
alignment, s0, s1, size, tag); size, tag);
} }
void HashKey::Allocate() void HashKey::Allocate() {
{ if ( key != nullptr && key != reinterpret_cast<char*>(&key_u) ) {
if ( key != nullptr && key != reinterpret_cast<char*>(&key_u) )
{
reporter->InternalWarning("usage error in HashKey::Allocate(): already allocated"); reporter->InternalWarning("usage error in HashKey::Allocate(): already allocated");
return; return;
} }
@ -360,15 +294,10 @@ void HashKey::Allocate()
write_size = 0; write_size = 0;
} }
void HashKey::Write(const char* tag, bool b) void HashKey::Write(const char* tag, bool b) { Write(tag, &b, sizeof(b), 0); }
{
Write(tag, &b, sizeof(b), 0);
}
void HashKey::Write(const char* tag, int i, bool align) void HashKey::Write(const char* tag, int i, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(i); Set(i);
return; return;
} }
@ -376,10 +305,8 @@ void HashKey::Write(const char* tag, int i, bool align)
Write(tag, &i, sizeof(i), align ? sizeof(i) : 0); Write(tag, &i, sizeof(i), align ? sizeof(i) : 0);
} }
void HashKey::Write(const char* tag, zeek_int_t bi, bool align) void HashKey::Write(const char* tag, zeek_int_t bi, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(bi); Set(bi);
return; return;
} }
@ -387,10 +314,8 @@ void HashKey::Write(const char* tag, zeek_int_t bi, bool align)
Write(tag, &bi, sizeof(bi), align ? sizeof(bi) : 0); Write(tag, &bi, sizeof(bi), align ? sizeof(bi) : 0);
} }
void HashKey::Write(const char* tag, zeek_uint_t bu, bool align) void HashKey::Write(const char* tag, zeek_uint_t bu, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(bu); Set(bu);
return; return;
} }
@ -398,10 +323,8 @@ void HashKey::Write(const char* tag, zeek_uint_t bu, bool align)
Write(tag, &bu, sizeof(bu), align ? sizeof(bu) : 0); Write(tag, &bu, sizeof(bu), align ? sizeof(bu) : 0);
} }
void HashKey::Write(const char* tag, uint32_t u, bool align) void HashKey::Write(const char* tag, uint32_t u, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(u); Set(u);
return; return;
} }
@ -409,10 +332,8 @@ void HashKey::Write(const char* tag, uint32_t u, bool align)
Write(tag, &u, sizeof(u), align ? sizeof(u) : 0); Write(tag, &u, sizeof(u), align ? sizeof(u) : 0);
} }
void HashKey::Write(const char* tag, double d, bool align) void HashKey::Write(const char* tag, double d, bool align) {
{ if ( ! IsAllocated() ) {
if ( ! IsAllocated() )
{
Set(d); Set(d);
return; return;
} }
@ -420,8 +341,7 @@ void HashKey::Write(const char* tag, double d, bool align)
Write(tag, &d, sizeof(d), align ? sizeof(d) : 0); Write(tag, &d, sizeof(d), align ? sizeof(d) : 0);
} }
void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignment) void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignment) {
{
size_t s0 = write_size; size_t s0 = write_size;
AlignWrite(alignment); AlignWrite(alignment);
size_t s1 = write_size; size_t s1 = write_size;
@ -430,21 +350,18 @@ void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignme
memcpy(key + write_size, bytes, n); memcpy(key + write_size, bytes, n);
write_size += n; write_size += n;
DBG_LOG(DBG_HASHKEY, "HashKey %p writing %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, DBG_LOG(DBG_HASHKEY, "HashKey %p writing %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, s0, s1, write_size,
s0, s1, write_size, tag); tag);
} }
void HashKey::SkipWrite(const char* tag, size_t n) void HashKey::SkipWrite(const char* tag, size_t n) {
{ DBG_LOG(DBG_HASHKEY, "HashKey %p skip-writing %lu: %lu -> %lu [%s]", this, n, write_size, write_size + n, tag);
DBG_LOG(DBG_HASHKEY, "HashKey %p skip-writing %lu: %lu -> %lu [%s]", this, n, write_size,
write_size + n, tag);
EnsureWriteSpace(n); EnsureWriteSpace(n);
write_size += n; write_size += n;
} }
void HashKey::AlignWrite(size_t alignment) void HashKey::AlignWrite(size_t alignment) {
{
ASSERT(IsAllocated()); ASSERT(IsAllocated());
if ( alignment == 0 ) if ( alignment == 0 )
@ -455,7 +372,8 @@ void HashKey::AlignWrite(size_t alignment)
write_size = util::memory_size_align(write_size, alignment); write_size = util::memory_size_align(write_size, alignment);
if ( write_size > size ) if ( write_size > size )
reporter->InternalError("buffer overflow in HashKey::AlignWrite(): " reporter->InternalError(
"buffer overflow in HashKey::AlignWrite(): "
"after alignment, %lu bytes used of %lu allocated", "after alignment, %lu bytes used of %lu allocated",
write_size, size); write_size, size);
@ -463,8 +381,7 @@ void HashKey::AlignWrite(size_t alignment)
key[old_size++] = '\0'; key[old_size++] = '\0';
} }
void HashKey::AlignRead(size_t alignment) const void HashKey::AlignRead(size_t alignment) const {
{
ASSERT(IsAllocated()); ASSERT(IsAllocated());
if ( alignment == 0 ) if ( alignment == 0 )
@ -475,43 +392,29 @@ void HashKey::AlignRead(size_t alignment) const
read_size = util::memory_size_align(read_size, alignment); read_size = util::memory_size_align(read_size, alignment);
if ( read_size > size ) if ( read_size > size )
reporter->InternalError("buffer overflow in HashKey::AlignRead(): " reporter->InternalError(
"buffer overflow in HashKey::AlignRead(): "
"after alignment, %lu bytes used of %lu allocated", "after alignment, %lu bytes used of %lu allocated",
read_size, size); read_size, size);
} }
void HashKey::Read(const char* tag, bool& b) const void HashKey::Read(const char* tag, bool& b) const { Read(tag, &b, sizeof(b), 0); }
{
Read(tag, &b, sizeof(b), 0);
}
void HashKey::Read(const char* tag, int& i, bool align) const void HashKey::Read(const char* tag, int& i, bool align) const { Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); }
{
void HashKey::Read(const char* tag, zeek_int_t& i, bool align) const {
Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); Read(tag, &i, sizeof(i), align ? sizeof(i) : 0);
} }
void HashKey::Read(const char* tag, zeek_int_t& i, bool align) const void HashKey::Read(const char* tag, zeek_uint_t& u, bool align) const {
{
Read(tag, &i, sizeof(i), align ? sizeof(i) : 0);
}
void HashKey::Read(const char* tag, zeek_uint_t& u, bool align) const
{
Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); Read(tag, &u, sizeof(u), align ? sizeof(u) : 0);
} }
void HashKey::Read(const char* tag, uint32_t& u, bool align) const void HashKey::Read(const char* tag, uint32_t& u, bool align) const { Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); }
{
Read(tag, &u, sizeof(u), align ? sizeof(u) : 0);
}
void HashKey::Read(const char* tag, double& d, bool align) const void HashKey::Read(const char* tag, double& d, bool align) const { Read(tag, &d, sizeof(d), align ? sizeof(d) : 0); }
{
Read(tag, &d, sizeof(d), align ? sizeof(d) : 0);
}
void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const {
{
size_t s0 = read_size; size_t s0 = read_size;
AlignRead(alignment); AlignRead(alignment);
size_t s1 = read_size; size_t s1 = read_size;
@ -522,55 +425,53 @@ void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const
// in memcpy even if the size is 0. // in memcpy even if the size is 0.
ASSERT(out != nullptr || (out == nullptr && n == 0)); ASSERT(out != nullptr || (out == nullptr && n == 0));
if ( n > 0 ) if ( n > 0 ) {
{
memcpy(out, key + read_size, n); memcpy(out, key + read_size, n);
read_size += n; read_size += n;
} }
DBG_LOG(DBG_HASHKEY, "HashKey %p reading %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, DBG_LOG(DBG_HASHKEY, "HashKey %p reading %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, s0, s1, read_size,
s0, s1, read_size, tag); tag);
} }
void HashKey::SkipRead(const char* tag, size_t n) const void HashKey::SkipRead(const char* tag, size_t n) const {
{ DBG_LOG(DBG_HASHKEY, "HashKey %p skip-reading %lu: %lu -> %lu [%s]", this, n, read_size, read_size + n, tag);
DBG_LOG(DBG_HASHKEY, "HashKey %p skip-reading %lu: %lu -> %lu [%s]", this, n, read_size,
read_size + n, tag);
EnsureReadSpace(n); EnsureReadSpace(n);
read_size += n; read_size += n;
} }
void HashKey::EnsureWriteSpace(size_t n) const void HashKey::EnsureWriteSpace(size_t n) const {
{
if ( n == 0 ) if ( n == 0 )
return; return;
if ( ! IsAllocated() ) if ( ! IsAllocated() )
reporter->InternalError("usage error in HashKey::EnsureWriteSpace(): " reporter->InternalError(
"usage error in HashKey::EnsureWriteSpace(): "
"size-checking unreserved buffer"); "size-checking unreserved buffer");
if ( write_size + n > size ) if ( write_size + n > size )
reporter->InternalError("buffer overflow in HashKey::Write(): writing %lu " reporter->InternalError(
"buffer overflow in HashKey::Write(): writing %lu "
"bytes with %lu remaining", "bytes with %lu remaining",
n, size - write_size); n, size - write_size);
} }
void HashKey::EnsureReadSpace(size_t n) const void HashKey::EnsureReadSpace(size_t n) const {
{
if ( n == 0 ) if ( n == 0 )
return; return;
if ( ! IsAllocated() ) if ( ! IsAllocated() )
reporter->InternalError("usage error in HashKey::EnsureReadSpace(): " reporter->InternalError(
"usage error in HashKey::EnsureReadSpace(): "
"size-checking unreserved buffer"); "size-checking unreserved buffer");
if ( read_size + n > size ) if ( read_size + n > size )
reporter->InternalError("buffer overflow in HashKey::EnsureReadSpace(): reading %lu " reporter->InternalError(
"buffer overflow in HashKey::EnsureReadSpace(): reading %lu "
"bytes with %lu remaining", "bytes with %lu remaining",
n, size - read_size); n, size - read_size);
} }
bool HashKey::operator==(const HashKey& other) const bool HashKey::operator==(const HashKey& other) const {
{
// Quick exit for the same object. // Quick exit for the same object.
if ( this == &other ) if ( this == &other )
return true; return true;
@ -578,8 +479,7 @@ bool HashKey::operator==(const HashKey& other) const
return Equal(other.key, other.size, other.hash); return Equal(other.key, other.size, other.hash);
} }
bool HashKey::operator!=(const HashKey& other) const bool HashKey::operator!=(const HashKey& other) const {
{
// Quick exit for different objects. // Quick exit for different objects.
if ( this != &other ) if ( this != &other )
return true; return true;
@ -587,8 +487,7 @@ bool HashKey::operator!=(const HashKey& other) const
return ! Equal(other.key, other.size, other.hash); return ! Equal(other.key, other.size, other.hash);
} }
bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash) const bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash) const {
{
// If the key memory is the same just return true. // If the key memory is the same just return true.
if ( key == other_key && size == other_size ) if ( key == other_key && size == other_size )
return true; return true;
@ -601,8 +500,7 @@ bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash)
return (hash == other_hash) && (size == other_size) && (memcmp(key, other_key, size) == 0); return (hash == other_hash) && (size == other_size) && (memcmp(key, other_key, size) == 0);
} }
HashKey& HashKey::operator=(const HashKey& other) HashKey& HashKey::operator=(const HashKey& other) {
{
if ( this == &other ) if ( this == &other )
return *this; return *this;
@ -620,8 +518,7 @@ HashKey& HashKey::operator=(const HashKey& other)
return *this; return *this;
} }
HashKey& HashKey::operator=(HashKey&& other) noexcept HashKey& HashKey::operator=(HashKey&& other) noexcept {
{
if ( this == &other ) if ( this == &other )
return *this; return *this;
@ -645,8 +542,7 @@ HashKey& HashKey::operator=(HashKey&& other) noexcept
TEST_SUITE_BEGIN("Hash"); TEST_SUITE_BEGIN("Hash");
TEST_CASE("equality") TEST_CASE("equality") {
{
HashKey h1(12345); HashKey h1(12345);
HashKey h2(12345); HashKey h2(12345);
HashKey h3(67890); HashKey h3(67890);
@ -655,8 +551,7 @@ TEST_CASE("equality")
CHECK(h1 != h3); CHECK(h1 != h3);
} }
TEST_CASE("copy assignment") TEST_CASE("copy assignment") {
{
HashKey h1(12345); HashKey h1(12345);
HashKey h2 = h1; HashKey h2 = h1;
HashKey h3{h1}; HashKey h3{h1};
@ -665,8 +560,7 @@ TEST_CASE("copy assignment")
CHECK(h1 == h3); CHECK(h1 == h3);
} }
TEST_CASE("move assignment") TEST_CASE("move assignment") {
{
HashKey h1(12345); HashKey h1(12345);
HashKey h2(12345); HashKey h2(12345);
HashKey h3(12345); HashKey h3(12345);

View file

@ -27,37 +27,32 @@
// to allow md5_hmac_bif access to the hmac seed // to allow md5_hmac_bif access to the hmac seed
#include "zeek/ZeekArgs.h" #include "zeek/ZeekArgs.h"
namespace zeek namespace zeek {
{
class String; class String;
class ODesc; class ODesc;
} } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class Frame; class Frame;
class BifReturnVal; class BifReturnVal;
} } // namespace zeek::detail
namespace zeek::BifFunc namespace zeek::BifFunc {
{
extern zeek::detail::BifReturnVal md5_hmac_bif(zeek::detail::Frame* frame, const zeek::Args*); extern zeek::detail::BifReturnVal md5_hmac_bif(zeek::detail::Frame* frame, const zeek::Args*);
} }
namespace zeek::detail namespace zeek::detail {
{
using hash_t = uint64_t; using hash_t = uint64_t;
using hash64_t = uint64_t; using hash64_t = uint64_t;
using hash128_t = uint64_t[2]; using hash128_t = uint64_t[2];
using hash256_t = uint64_t[4]; using hash256_t = uint64_t[4];
class KeyedHash class KeyedHash {
{
public: public:
/** /**
* Generate a 64 bit digest hash. * Generate a 64 bit digest hash.
@ -215,22 +210,15 @@ private:
inline static uint8_t shared_hmac_md5_key[16]; inline static uint8_t shared_hmac_md5_key[16];
inline static bool seeds_initialized = false; inline static bool seeds_initialized = false;
friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, unsigned char digest[16]);
unsigned char digest[16]);
friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*); friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*);
}; };
enum HashKeyTag enum HashKeyTag { HASH_KEY_INT, HASH_KEY_DOUBLE, HASH_KEY_STRING };
{
HASH_KEY_INT,
HASH_KEY_DOUBLE,
HASH_KEY_STRING
};
constexpr int NUM_HASH_KEYS = HASH_KEY_STRING + 1; constexpr int NUM_HASH_KEYS = HASH_KEY_STRING + 1;
class HashKey class HashKey {
{
public: public:
explicit HashKey() { key_u.u32 = 0; } explicit HashKey() { key_u.u32 = 0; }
explicit HashKey(bool b); explicit HashKey(bool b);
@ -280,15 +268,15 @@ public:
// A HashKey is "allocated" when the underlying key points somewhere // A HashKey is "allocated" when the underlying key points somewhere
// other than our internal key_u union. This is almost like // other than our internal key_u union. This is almost like
// is_our_dynamic, but remains true also after TakeKey(). // is_our_dynamic, but remains true also after TakeKey().
bool IsAllocated() const bool IsAllocated() const { return (key != nullptr && key != reinterpret_cast<const char*>(&key_u)); }
{
return (key != nullptr && key != reinterpret_cast<const char*>(&key_u));
}
// Buffer size reservation. Repeated calls to these methods // Buffer size reservation. Repeated calls to these methods
// incrementally build up the eventual buffer size to be allocated via // incrementally build up the eventual buffer size to be allocated via
// Allocate(). // Allocate().
template <typename T> void ReserveType(const char* tag) { Reserve(tag, sizeof(T), sizeof(T)); } template<typename T>
void ReserveType(const char* tag) {
Reserve(tag, sizeof(T), sizeof(T));
}
void Reserve(const char* tag, size_t addl_size, size_t alignment = 0); void Reserve(const char* tag, size_t addl_size, size_t alignment = 0);
// Allocates the reserved amount of memory // Allocates the reserved amount of memory

281
src/ID.cc
View file

@ -22,8 +22,7 @@
#include "zeek/zeekygen/ScriptInfo.h" #include "zeek/zeekygen/ScriptInfo.h"
#include "zeek/zeekygen/utils.h" #include "zeek/zeekygen/utils.h"
namespace zeek namespace zeek {
{
RecordTypePtr id::conn_id; RecordTypePtr id::conn_id;
RecordTypePtr id::endpoint; RecordTypePtr id::endpoint;
@ -37,13 +36,9 @@ TableTypePtr id::count_set;
VectorTypePtr id::string_vec; VectorTypePtr id::string_vec;
VectorTypePtr id::index_vec; VectorTypePtr id::index_vec;
const detail::IDPtr& id::find(std::string_view name) const detail::IDPtr& id::find(std::string_view name) { return zeek::detail::global_scope()->Find(name); }
{
return zeek::detail::global_scope()->Find(name);
}
const TypePtr& id::find_type(std::string_view name) const TypePtr& id::find_type(std::string_view name) {
{
auto id = zeek::detail::global_scope()->Find(name); auto id = zeek::detail::global_scope()->Find(name);
if ( ! id ) if ( ! id )
@ -52,8 +47,7 @@ const TypePtr& id::find_type(std::string_view name)
return id->GetType(); return id->GetType();
} }
const ValPtr& id::find_val(std::string_view name) const ValPtr& id::find_val(std::string_view name) {
{
auto id = zeek::detail::global_scope()->Find(name); auto id = zeek::detail::global_scope()->Find(name);
if ( ! id ) if ( ! id )
@ -62,36 +56,31 @@ const ValPtr& id::find_val(std::string_view name)
return id->GetVal(); return id->GetVal();
} }
const ValPtr& id::find_const(std::string_view name) const ValPtr& id::find_const(std::string_view name) {
{
auto id = zeek::detail::global_scope()->Find(name); auto id = zeek::detail::global_scope()->Find(name);
if ( ! id ) if ( ! id )
reporter->InternalError("Failed to find variable named: %s", std::string(name).data()); reporter->InternalError("Failed to find variable named: %s", std::string(name).data());
if ( ! id->IsConst() ) if ( ! id->IsConst() )
reporter->InternalError("Variable is not 'const', but expected to be: %s", reporter->InternalError("Variable is not 'const', but expected to be: %s", std::string(name).data());
std::string(name).data());
return id->GetVal(); return id->GetVal();
} }
FuncPtr id::find_func(std::string_view name) FuncPtr id::find_func(std::string_view name) {
{
const auto& v = id::find_val(name); const auto& v = id::find_val(name);
if ( ! v ) if ( ! v )
return nullptr; return nullptr;
if ( ! IsFunc(v->GetType()->Tag()) ) if ( ! IsFunc(v->GetType()->Tag()) )
reporter->InternalError("Expected variable '%s' to be a function", reporter->InternalError("Expected variable '%s' to be a function", std::string(name).data());
std::string(name).data());
return v.get()->As<FuncVal*>()->AsFuncPtr(); return v.get()->As<FuncVal*>()->AsFuncPtr();
} }
void id::detail::init_types() void id::detail::init_types() {
{
conn_id = id::find_type<RecordType>("conn_id"); conn_id = id::find_type<RecordType>("conn_id");
endpoint = id::find_type<RecordType>("endpoint"); endpoint = id::find_type<RecordType>("endpoint");
connection = id::find_type<RecordType>("connection"); connection = id::find_type<RecordType>("connection");
@ -105,11 +94,9 @@ void id::detail::init_types()
index_vec = id::find_type<VectorType>("index_vec"); index_vec = id::find_type<VectorType>("index_vec");
} }
namespace detail namespace detail {
{
ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export) ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export) {
{
name = util::copy_string(arg_name); name = util::copy_string(arg_name);
scope = arg_scope; scope = arg_scope;
is_export = arg_is_export; is_export = arg_is_export;
@ -130,29 +117,18 @@ ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export)
SetLocationInfo(&start_location, &end_location); SetLocationInfo(&start_location, &end_location);
} }
ID::~ID() ID::~ID() {
{
ClearOptInfo(); ClearOptInfo();
delete[] name; delete[] name;
} }
std::string ID::ModuleName() const std::string ID::ModuleName() const { return extract_module_name(name); }
{
return extract_module_name(name);
}
void ID::SetType(TypePtr t) void ID::SetType(TypePtr t) { type = std::move(t); }
{
type = std::move(t);
}
void ID::ClearVal() void ID::ClearVal() { val = nullptr; }
{
val = nullptr;
}
void ID::SetVal(ValPtr v) void ID::SetVal(ValPtr v) {
{
val = std::move(v); val = std::move(v);
Modified(); Modified();
@ -160,13 +136,10 @@ void ID::SetVal(ValPtr v)
UpdateValID(); UpdateValID();
#endif #endif
if ( type && val && type->Tag() == TYPE_FUNC && if ( type && val && type->Tag() == TYPE_FUNC && type->AsFuncType()->Flavor() == FUNC_FLAVOR_EVENT ) {
type->AsFuncType()->Flavor() == FUNC_FLAVOR_EVENT )
{
EventHandler* handler = event_registry->Lookup(name); EventHandler* handler = event_registry->Lookup(name);
auto func = val.get()->As<FuncVal*>()->AsFuncPtr(); auto func = val.get()->As<FuncVal*>()->AsFuncPtr();
if ( ! handler ) if ( ! handler ) {
{
handler = new EventHandler(name); handler = new EventHandler(name);
handler->SetFunc(func); handler->SetFunc(func);
event_registry->Register(handler, true); event_registry->Register(handler, true);
@ -174,8 +147,7 @@ void ID::SetVal(ValPtr v)
if ( ! IsExport() ) if ( ! IsExport() )
register_new_event({NewRef{}, this}); register_new_event({NewRef{}, this});
} }
else else {
{
// Otherwise, internally defined events cannot // Otherwise, internally defined events cannot
// have local handler. // have local handler.
handler->SetFunc(func); handler->SetFunc(func);
@ -183,97 +155,77 @@ void ID::SetVal(ValPtr v)
} }
} }
void ID::SetVal(ValPtr v, InitClass c) void ID::SetVal(ValPtr v, InitClass c) {
{ if ( c == INIT_NONE || c == INIT_FULL ) {
if ( c == INIT_NONE || c == INIT_FULL )
{
SetVal(std::move(v)); SetVal(std::move(v));
return; return;
} }
if ( type->Tag() != TYPE_TABLE && (type->Tag() != TYPE_PATTERN || c == INIT_REMOVE) && if ( type->Tag() != TYPE_TABLE && (type->Tag() != TYPE_PATTERN || c == INIT_REMOVE) &&
(type->Tag() != TYPE_VECTOR || c == INIT_REMOVE) ) (type->Tag() != TYPE_VECTOR || c == INIT_REMOVE) ) {
{
if ( c == INIT_EXTRA ) if ( c == INIT_EXTRA )
Error("+= initializer only applies to tables, sets, vectors and patterns", v.get()); Error("+= initializer only applies to tables, sets, vectors and patterns", v.get());
else else
Error("-= initializer only applies to tables and sets", v.get()); Error("-= initializer only applies to tables and sets", v.get());
} }
else else {
{ if ( c == INIT_EXTRA ) {
if ( c == INIT_EXTRA ) if ( ! val ) {
{
if ( ! val )
{
SetVal(std::move(v)); SetVal(std::move(v));
return; return;
} }
else else
v->AddTo(val.get(), false); v->AddTo(val.get(), false);
} }
else else {
{
if ( val ) if ( val )
v->RemoveFrom(val.get()); v->RemoveFrom(val.get());
} }
} }
} }
void ID::SetVal(ExprPtr ev, InitClass c) void ID::SetVal(ExprPtr ev, InitClass c) {
{
const auto& a = attrs->Find(c == INIT_EXTRA ? ATTR_ADD_FUNC : ATTR_DEL_FUNC); const auto& a = attrs->Find(c == INIT_EXTRA ? ATTR_ADD_FUNC : ATTR_DEL_FUNC);
if ( ! a ) if ( ! a )
Internal("no add/delete function in ID::SetVal"); Internal("no add/delete function in ID::SetVal");
if ( ! val ) if ( ! val ) {
{ Error(zeek::util::fmt("%s initializer applied to ID without value", c == INIT_EXTRA ? "+=" : "-="), this);
Error(zeek::util::fmt("%s initializer applied to ID without value",
c == INIT_EXTRA ? "+=" : "-="),
this);
return; return;
} }
EvalFunc(a->GetExpr(), std::move(ev)); EvalFunc(a->GetExpr(), std::move(ev));
} }
bool ID::IsRedefinable() const bool ID::IsRedefinable() const { return GetAttr(ATTR_REDEF) != nullptr; }
{
return GetAttr(ATTR_REDEF) != nullptr;
}
void ID::SetAttrs(AttributesPtr a) void ID::SetAttrs(AttributesPtr a) {
{
attrs = nullptr; attrs = nullptr;
AddAttrs(std::move(a)); AddAttrs(std::move(a));
} }
void ID::UpdateValAttrs() void ID::UpdateValAttrs() {
{
if ( ! attrs ) if ( ! attrs )
return; return;
auto tag = GetType()->Tag(); auto tag = GetType()->Tag();
if ( tag == TYPE_FUNC ) if ( tag == TYPE_FUNC ) {
{
const auto& attr = attrs->Find(ATTR_ERROR_HANDLER); const auto& attr = attrs->Find(ATTR_ERROR_HANDLER);
if ( attr ) if ( attr )
event_registry->SetErrorHandler(Name()); event_registry->SetErrorHandler(Name());
} }
if ( tag == TYPE_RECORD ) if ( tag == TYPE_RECORD ) {
{
const auto& attr = attrs->Find(ATTR_LOG); const auto& attr = attrs->Find(ATTR_LOG);
if ( attr ) if ( attr ) {
{
// Apply &log to all record fields. // Apply &log to all record fields.
RecordType* rt = GetType()->AsRecordType(); RecordType* rt = GetType()->AsRecordType();
for ( int i = 0; i < rt->NumFields(); ++i ) for ( int i = 0; i < rt->NumFields(); ++i ) {
{
TypeDecl* fd = rt->FieldDecl(i); TypeDecl* fd = rt->FieldDecl(i);
if ( ! fd->attrs ) if ( ! fd->attrs )
@ -296,26 +248,18 @@ void ID::UpdateValAttrs()
val->AsFile()->SetAttrs(attrs.get()); val->AsFile()->SetAttrs(attrs.get());
} }
const AttrPtr& ID::GetAttr(AttrTag t) const const AttrPtr& ID::GetAttr(AttrTag t) const { return attrs ? attrs->Find(t) : Attr::nil; }
{
return attrs ? attrs->Find(t) : Attr::nil;
}
bool ID::IsDeprecated() const bool ID::IsDeprecated() const { return GetAttr(ATTR_DEPRECATED) != nullptr; }
{
return GetAttr(ATTR_DEPRECATED) != nullptr;
}
void ID::MakeDeprecated(ExprPtr deprecation) void ID::MakeDeprecated(ExprPtr deprecation) {
{
if ( IsDeprecated() ) if ( IsDeprecated() )
return; return;
AddAttr(make_intrusive<Attr>(ATTR_DEPRECATED, std::move(deprecation))); AddAttr(make_intrusive<Attr>(ATTR_DEPRECATED, std::move(deprecation)));
} }
std::string ID::GetDeprecationWarning() const std::string ID::GetDeprecationWarning() const {
{
std::string result; std::string result;
const auto& depr_attr = GetAttr(ATTR_DEPRECATED); const auto& depr_attr = GetAttr(ATTR_DEPRECATED);
@ -328,15 +272,13 @@ std::string ID::GetDeprecationWarning() const
return util::fmt("deprecated (%s): %s", Name(), result.c_str()); return util::fmt("deprecated (%s): %s", Name(), result.c_str());
} }
void ID::AddAttr(AttrPtr a, bool is_redef) void ID::AddAttr(AttrPtr a, bool is_redef) {
{
std::vector<AttrPtr> attrv{std::move(a)}; std::vector<AttrPtr> attrv{std::move(a)};
auto attrs = make_intrusive<Attributes>(std::move(attrv), GetType(), false, IsGlobal()); auto attrs = make_intrusive<Attributes>(std::move(attrv), GetType(), false, IsGlobal());
AddAttrs(std::move(attrs), is_redef); AddAttrs(std::move(attrs), is_redef);
} }
void ID::AddAttrs(AttributesPtr a, bool is_redef) void ID::AddAttrs(AttributesPtr a, bool is_redef) {
{
if ( attrs ) if ( attrs )
attrs->AddAttrs(a, is_redef); attrs->AddAttrs(a, is_redef);
else else
@ -345,14 +287,12 @@ void ID::AddAttrs(AttributesPtr a, bool is_redef)
UpdateValAttrs(); UpdateValAttrs();
} }
void ID::RemoveAttr(AttrTag a) void ID::RemoveAttr(AttrTag a) {
{
if ( attrs ) if ( attrs )
attrs->RemoveAttr(a); attrs->RemoveAttr(a);
} }
void ID::SetOption() void ID::SetOption() {
{
if ( is_option ) if ( is_option )
return; return;
@ -363,8 +303,7 @@ void ID::SetOption()
AddAttr(make_intrusive<Attr>(ATTR_REDEF)); AddAttr(make_intrusive<Attr>(ATTR_REDEF));
} }
void ID::EvalFunc(ExprPtr ef, ExprPtr ev) void ID::EvalFunc(ExprPtr ef, ExprPtr ev) {
{
auto arg1 = make_intrusive<detail::ConstExpr>(val); auto arg1 = make_intrusive<detail::ConstExpr>(val);
auto args = make_intrusive<detail::ListExpr>(); auto args = make_intrusive<detail::ListExpr>();
args->Append(std::move(arg1)); args->Append(std::move(arg1));
@ -373,13 +312,11 @@ void ID::EvalFunc(ExprPtr ef, ExprPtr ev)
SetVal(ce->Eval(nullptr)); SetVal(ce->Eval(nullptr));
} }
TraversalCode ID::Traverse(TraversalCallback* cb) const TraversalCode ID::Traverse(TraversalCallback* cb) const {
{
TraversalCode tc = cb->PreID(this); TraversalCode tc = cb->PreID(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
if ( is_type ) if ( is_type ) {
{
tc = cb->PreTypedef(this); tc = cb->PreTypedef(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
@ -388,14 +325,12 @@ TraversalCode ID::Traverse(TraversalCallback* cb) const
} }
// FIXME: Perhaps we should be checking at other than global scope. // FIXME: Perhaps we should be checking at other than global scope.
else if ( val && IsFunc(val->GetType()->Tag()) && cb->current_scope == detail::global_scope() ) else if ( val && IsFunc(val->GetType()->Tag()) && cb->current_scope == detail::global_scope() ) {
{
tc = val->AsFunc()->Traverse(cb); tc = val->AsFunc()->Traverse(cb);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
} }
else if ( ! is_enum_const ) else if ( ! is_enum_const ) {
{
tc = cb->PreDecl(this); tc = cb->PreDecl(this);
HANDLE_TC_STMT_PRE(tc); HANDLE_TC_STMT_PRE(tc);
@ -407,42 +342,33 @@ TraversalCode ID::Traverse(TraversalCallback* cb) const
HANDLE_TC_EXPR_POST(tc); HANDLE_TC_EXPR_POST(tc);
} }
void ID::Error(const char* msg, const Obj* o2) void ID::Error(const char* msg, const Obj* o2) {
{
Obj::Error(msg, o2, true); Obj::Error(msg, o2, true);
SetType(error_type()); SetType(error_type());
} }
void ID::Describe(ODesc* d) const void ID::Describe(ODesc* d) const { d->Add(name); }
{
d->Add(name);
}
void ID::DescribeExtended(ODesc* d) const void ID::DescribeExtended(ODesc* d) const {
{
d->Add(name); d->Add(name);
if ( type ) if ( type ) {
{
d->Add(" : "); d->Add(" : ");
type->Describe(d); type->Describe(d);
} }
if ( val ) if ( val ) {
{
d->Add(" = "); d->Add(" = ");
val->Describe(d); val->Describe(d);
} }
if ( attrs ) if ( attrs ) {
{
d->Add(" "); d->Add(" ");
attrs->Describe(d); attrs->Describe(d);
} }
} }
void ID::DescribeReSTShort(ODesc* d) const void ID::DescribeReSTShort(ODesc* d) const {
{
if ( is_type ) if ( is_type )
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
else else
@ -451,26 +377,19 @@ void ID::DescribeReSTShort(ODesc* d) const
d->Add(name); d->Add(name);
d->Add("`"); d->Add("`");
if ( type ) if ( type ) {
{
d->Add(": "); d->Add(": ");
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
if ( ! is_type && ! type->GetName().empty() ) if ( ! is_type && ! type->GetName().empty() )
d->Add(type->GetName().c_str()); d->Add(type->GetName().c_str());
else else {
{
TypeTag t = type->Tag(); TypeTag t = type->Tag();
switch ( t ) switch ( t ) {
{ case TYPE_TABLE: d->Add(type->IsSet() ? "set" : type_name(t)); break;
case TYPE_TABLE:
d->Add(type->IsSet() ? "set" : type_name(t));
break;
case TYPE_FUNC: case TYPE_FUNC: d->Add(type->AsFuncType()->FlavorString().c_str()); break;
d->Add(type->AsFuncType()->FlavorString().c_str());
break;
case TYPE_ENUM: case TYPE_ENUM:
if ( is_type ) if ( is_type )
@ -479,26 +398,21 @@ void ID::DescribeReSTShort(ODesc* d) const
d->Add(zeekygen_mgr->GetEnumTypeName(Name()).c_str()); d->Add(zeekygen_mgr->GetEnumTypeName(Name()).c_str());
break; break;
default: default: d->Add(type_name(t)); break;
d->Add(type_name(t));
break;
} }
} }
d->Add("`"); d->Add("`");
} }
if ( attrs ) if ( attrs ) {
{
d->SP(); d->SP();
attrs->DescribeReST(d, true); attrs->DescribeReST(d, true);
} }
} }
void ID::DescribeReST(ODesc* d, bool roles_only) const void ID::DescribeReST(ODesc* d, bool roles_only) const {
{ if ( roles_only ) {
if ( roles_only )
{
if ( is_type ) if ( is_type )
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
else else
@ -506,8 +420,7 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->Add(name); d->Add(name);
d->Add("`"); d->Add("`");
} }
else else {
{
if ( is_type ) if ( is_type )
d->Add(".. zeek:type:: "); d->Add(".. zeek:type:: ");
else else
@ -515,8 +428,7 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->Add(name); d->Add(name);
if ( auto sc = zeek::zeekygen::detail::source_code_range(this) ) if ( auto sc = zeek::zeekygen::detail::source_code_range(this) ) {
{
d->PushIndent(); d->PushIndent();
d->Add(util::fmt(":source-code: %s", sc->data())); d->Add(util::fmt(":source-code: %s", sc->data()));
d->PopIndentNoNL(); d->PopIndentNoNL();
@ -526,36 +438,28 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->PushIndent(); d->PushIndent();
d->NL(); d->NL();
if ( type ) if ( type ) {
{
d->Add(":Type: "); d->Add(":Type: ");
if ( ! is_type && ! type->GetName().empty() ) if ( ! is_type && ! type->GetName().empty() ) {
{
d->Add(":zeek:type:`"); d->Add(":zeek:type:`");
d->Add(type->GetName()); d->Add(type->GetName());
d->Add("`"); d->Add("`");
} }
else else {
{
type->DescribeReST(d, roles_only); type->DescribeReST(d, roles_only);
if ( IsFunc(type->Tag()) ) if ( IsFunc(type->Tag()) ) {
{
auto ft = type->AsFuncType(); auto ft = type->AsFuncType();
if ( ft->Flavor() == FUNC_FLAVOR_EVENT || ft->Flavor() == FUNC_FLAVOR_HOOK ) if ( ft->Flavor() == FUNC_FLAVOR_EVENT || ft->Flavor() == FUNC_FLAVOR_HOOK ) {
{
const auto& protos = ft->Prototypes(); const auto& protos = ft->Prototypes();
if ( protos.size() > 1 ) if ( protos.size() > 1 ) {
{
auto first = true; auto first = true;
for ( const auto& proto : protos ) for ( const auto& proto : protos ) {
{ if ( first ) {
if ( first )
{
first = false; first = false;
continue; continue;
} }
@ -575,8 +479,7 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
d->NL(); d->NL();
} }
if ( attrs ) if ( attrs ) {
{
d->Add(":Attributes: "); d->Add(":Attributes: ");
attrs->DescribeReST(d); attrs->DescribeReST(d);
d->NL(); d->NL();
@ -586,20 +489,16 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
// Values within Version module are likely to include a // Values within Version module are likely to include a
// constantly-changing version number and be a frequent // constantly-changing version number and be a frequent
// source of error/desynchronization, so don't include them. // source of error/desynchronization, so don't include them.
ModuleName() != "Version" ) ModuleName() != "Version" ) {
{
d->Add(":Default:"); d->Add(":Default:");
auto ii = zeekygen_mgr->GetIdentifierInfo(Name()); auto ii = zeekygen_mgr->GetIdentifierInfo(Name());
auto redefs = ii->GetRedefs(); auto redefs = ii->GetRedefs();
const auto& iv = ! redefs.empty() && ii->InitialVal() ? ii->InitialVal() : val; const auto& iv = ! redefs.empty() && ii->InitialVal() ? ii->InitialVal() : val;
if ( type->InternalType() == TYPE_INTERNAL_OTHER ) if ( type->InternalType() == TYPE_INTERNAL_OTHER ) {
{ switch ( type->Tag() ) {
switch ( type->Tag() )
{
case TYPE_TABLE: case TYPE_TABLE:
if ( iv->AsTable()->Length() == 0 ) if ( iv->AsTable()->Length() == 0 ) {
{
d->Add(" ``{}``"); d->Add(" ``{}``");
d->NL(); d->NL();
break; break;
@ -618,15 +517,13 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
} }
} }
else else {
{
d->SP(); d->SP();
iv->DescribeReST(d); iv->DescribeReST(d);
d->NL(); d->NL();
} }
for ( auto& ir : redefs ) for ( auto& ir : redefs ) {
{
if ( ! ir->init_expr ) if ( ! ir->init_expr )
continue; continue;
@ -664,20 +561,15 @@ void ID::DescribeReST(ODesc* d, bool roles_only) const
} }
#ifdef DEBUG #ifdef DEBUG
void ID::UpdateValID() void ID::UpdateValID() {
{
if ( IsGlobal() && val && name && name[0] != '#' ) if ( IsGlobal() && val && name && name[0] != '#' )
val->SetID(this); val->SetID(this);
} }
#endif #endif
void ID::AddOptionHandler(FuncPtr callback, int priority) void ID::AddOptionHandler(FuncPtr callback, int priority) { option_handlers.emplace(priority, std::move(callback)); }
{
option_handlers.emplace(priority, std::move(callback));
}
std::vector<Func*> ID::GetOptionHandlers() const std::vector<Func*> ID::GetOptionHandlers() const {
{
// multimap is sorted // multimap is sorted
// It might be worth caching this if we expect it to be called // It might be worth caching this if we expect it to be called
// a lot... // a lot...
@ -687,8 +579,7 @@ std::vector<Func*> ID::GetOptionHandlers() const
return v; return v;
} }
void ID::ClearOptInfo() void ID::ClearOptInfo() {
{
delete opt_info; delete opt_info;
opt_info = nullptr; opt_info = nullptr;
} }

View file

@ -13,8 +13,7 @@
#include "zeek/Obj.h" #include "zeek/Obj.h"
#include "zeek/TraverseTypes.h" #include "zeek/TraverseTypes.h"
namespace zeek namespace zeek {
{
class Func; class Func;
class Val; class Val;
@ -31,29 +30,22 @@ using EnumTypePtr = IntrusivePtr<EnumType>;
using ValPtr = IntrusivePtr<Val>; using ValPtr = IntrusivePtr<Val>;
using FuncPtr = IntrusivePtr<Func>; using FuncPtr = IntrusivePtr<Func>;
} } // namespace zeek
namespace zeek::detail namespace zeek::detail {
{
class Attributes; class Attributes;
class Expr; class Expr;
using ExprPtr = IntrusivePtr<Expr>; using ExprPtr = IntrusivePtr<Expr>;
enum InitClass enum InitClass {
{
INIT_NONE, INIT_NONE,
INIT_FULL, INIT_FULL,
INIT_EXTRA, INIT_EXTRA,
INIT_REMOVE, INIT_REMOVE,
INIT_SKIP, INIT_SKIP,
}; };
enum IDScope enum IDScope { SCOPE_FUNCTION, SCOPE_MODULE, SCOPE_GLOBAL };
{
SCOPE_FUNCTION,
SCOPE_MODULE,
SCOPE_GLOBAL
};
class ID; class ID;
using IDPtr = IntrusivePtr<ID>; using IDPtr = IntrusivePtr<ID>;
@ -61,8 +53,7 @@ using IDSet = std::unordered_set<const ID*>;
class IDOptInfo; class IDOptInfo;
class ID final : public Obj, public notifier::detail::Modifiable class ID final : public Obj, public notifier::detail::Modifiable {
{
public: public:
static inline const IDPtr nil; static inline const IDPtr nil;
@ -84,7 +75,10 @@ public:
const TypePtr& GetType() const { return type; } const TypePtr& GetType() const { return type; }
template <class T> IntrusivePtr<T> GetType() const { return cast_intrusive<T>(type); } template<class T>
IntrusivePtr<T> GetType() const {
return cast_intrusive<T>(type);
}
bool IsType() const { return is_type; } bool IsType() const { return is_type; }
@ -183,8 +177,7 @@ protected:
} // namespace zeek::detail } // namespace zeek::detail
namespace zeek::id namespace zeek::id {
{
/** /**
* Lookup an ID in the global module and return it, if one exists; * Lookup an ID in the global module and return it, if one exists;
@ -208,8 +201,8 @@ const TypePtr& find_type(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The type of the identifier. * @return The type of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_type(std::string_view name) template<class T>
{ IntrusivePtr<T> find_type(std::string_view name) {
return cast_intrusive<T>(find_type(name)); return cast_intrusive<T>(find_type(name));
} }
@ -227,8 +220,8 @@ const ValPtr& find_val(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The current value of the identifier. * @return The current value of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_val(std::string_view name) template<class T>
{ IntrusivePtr<T> find_val(std::string_view name) {
return cast_intrusive<T>(find_val(name)); return cast_intrusive<T>(find_val(name));
} }
@ -246,8 +239,8 @@ const ValPtr& find_const(std::string_view name);
* @param name The identifier name to lookup * @param name The identifier name to lookup
* @return The current value of the identifier. * @return The current value of the identifier.
*/ */
template <class T> IntrusivePtr<T> find_const(std::string_view name) template<class T>
{ IntrusivePtr<T> find_const(std::string_view name) {
return cast_intrusive<T>(find_const(name)); return cast_intrusive<T>(find_const(name));
} }
@ -271,8 +264,7 @@ extern TableTypePtr count_set;
extern VectorTypePtr string_vec; extern VectorTypePtr string_vec;
extern VectorTypePtr index_vec; extern VectorTypePtr index_vec;
namespace detail namespace detail {
{
void init_types(); void init_types();

353
src/IP.cc
View file

@ -13,13 +13,10 @@
#include "zeek/Var.h" #include "zeek/Var.h"
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
namespace zeek namespace zeek {
{
bool IPv6_Hdr::IsOptionTruncated(uint16_t off) const bool IPv6_Hdr::IsOptionTruncated(uint16_t off) const {
{ if ( Length() < off ) {
if ( Length() < off )
{
reporter->Weird("truncated_IPv6_option"); reporter->Weird("truncated_IPv6_option");
return true; return true;
} }
@ -27,27 +24,23 @@ bool IPv6_Hdr::IsOptionTruncated(uint16_t off) const
return false; return false;
} }
static VectorValPtr BuildOptionsVal(const u_char* data, int len) static VectorValPtr BuildOptionsVal(const u_char* data, int len) {
{
auto vv = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_options")); auto vv = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_options"));
while ( len > 0 && static_cast<size_t>(len) >= sizeof(struct ip6_opt) ) while ( len > 0 && static_cast<size_t>(len) >= sizeof(struct ip6_opt) ) {
{
static auto ip6_option_type = id::find_type<RecordType>("ip6_option"); static auto ip6_option_type = id::find_type<RecordType>("ip6_option");
const struct ip6_opt* opt = (const struct ip6_opt*)data; const struct ip6_opt* opt = (const struct ip6_opt*)data;
auto rv = make_intrusive<RecordVal>(ip6_option_type); auto rv = make_intrusive<RecordVal>(ip6_option_type);
rv->Assign(0, opt->ip6o_type); rv->Assign(0, opt->ip6o_type);
if ( opt->ip6o_type == 0 ) if ( opt->ip6o_type == 0 ) {
{
// Pad1 option // Pad1 option
rv->Assign(1, 0); rv->Assign(1, 0);
rv->Assign(2, val_mgr->EmptyString()); rv->Assign(2, val_mgr->EmptyString());
data += sizeof(uint8_t); data += sizeof(uint8_t);
len -= sizeof(uint8_t); len -= sizeof(uint8_t);
} }
else else {
{
// PadN or other option // PadN or other option
uint16_t off = 2 * sizeof(uint8_t); uint16_t off = 2 * sizeof(uint8_t);
@ -66,14 +59,11 @@ static VectorValPtr BuildOptionsVal(const u_char* data, int len)
return vv; return vv;
} }
RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const {
{
RecordValPtr rv; RecordValPtr rv;
switch ( type ) switch ( type ) {
{ case IPPROTO_IPV6: {
case IPPROTO_IPV6:
{
static auto ip6_hdr_type = id::find_type<RecordType>("ip6_hdr"); static auto ip6_hdr_type = id::find_type<RecordType>("ip6_hdr");
rv = make_intrusive<RecordVal>(ip6_hdr_type); rv = make_intrusive<RecordVal>(ip6_hdr_type);
const struct ip6_hdr* ip6 = (const struct ip6_hdr*)data; const struct ip6_hdr* ip6 = (const struct ip6_hdr*)data;
@ -87,11 +77,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
if ( ! chain ) if ( ! chain )
chain = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_ext_hdr_chain")); chain = make_intrusive<VectorVal>(id::find_type<VectorType>("ip6_ext_hdr_chain"));
rv->Assign(7, std::move(chain)); rv->Assign(7, std::move(chain));
} } break;
break;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS: {
{
uint16_t off = 2 * sizeof(uint8_t); uint16_t off = 2 * sizeof(uint8_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
return nullptr; return nullptr;
@ -102,11 +90,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(0, hbh->ip6h_nxt); rv->Assign(0, hbh->ip6h_nxt);
rv->Assign(1, hbh->ip6h_len); rv->Assign(1, hbh->ip6h_len);
rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); rv->Assign(2, BuildOptionsVal(data + off, Length() - off));
} } break;
break;
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS: {
{
uint16_t off = 2 * sizeof(uint8_t); uint16_t off = 2 * sizeof(uint8_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
return nullptr; return nullptr;
@ -117,11 +103,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(0, dst->ip6d_nxt); rv->Assign(0, dst->ip6d_nxt);
rv->Assign(1, dst->ip6d_len); rv->Assign(1, dst->ip6d_len);
rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); rv->Assign(2, BuildOptionsVal(data + off, Length() - off));
} } break;
break;
case IPPROTO_ROUTING: case IPPROTO_ROUTING: {
{
uint16_t off = 4 * sizeof(uint8_t); uint16_t off = 4 * sizeof(uint8_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
return nullptr; return nullptr;
@ -134,11 +118,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(2, rt->ip6r_type); rv->Assign(2, rt->ip6r_type);
rv->Assign(3, rt->ip6r_segleft); rv->Assign(3, rt->ip6r_segleft);
rv->Assign(4, new String(data + off, Length() - off, true)); rv->Assign(4, new String(data + off, Length() - off, true));
} } break;
break;
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT: {
{
static auto ip6_fragment_type = id::find_type<RecordType>("ip6_fragment"); static auto ip6_fragment_type = id::find_type<RecordType>("ip6_fragment");
rv = make_intrusive<RecordVal>(ip6_fragment_type); rv = make_intrusive<RecordVal>(ip6_fragment_type);
const struct ip6_frag* frag = (const struct ip6_frag*)data; const struct ip6_frag* frag = (const struct ip6_frag*)data;
@ -148,11 +130,9 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(3, (ntohs(frag->ip6f_offlg) & 0x0006) >> 1); rv->Assign(3, (ntohs(frag->ip6f_offlg) & 0x0006) >> 1);
rv->Assign(4, static_cast<bool>(ntohs(frag->ip6f_offlg) & 0x0001)); rv->Assign(4, static_cast<bool>(ntohs(frag->ip6f_offlg) & 0x0001));
rv->Assign(5, static_cast<uint32_t>(ntohl(frag->ip6f_ident))); rv->Assign(5, static_cast<uint32_t>(ntohl(frag->ip6f_ident)));
} } break;
break;
case IPPROTO_AH: case IPPROTO_AH: {
{
static auto ip6_ah_type = id::find_type<RecordType>("ip6_ah"); static auto ip6_ah_type = id::find_type<RecordType>("ip6_ah");
rv = make_intrusive<RecordVal>(ip6_ah_type); rv = make_intrusive<RecordVal>(ip6_ah_type);
rv->Assign(0, ((ip6_ext*)data)->ip6e_nxt); rv->Assign(0, ((ip6_ext*)data)->ip6e_nxt);
@ -160,29 +140,24 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
rv->Assign(2, ntohs(((uint16_t*)data)[1])); rv->Assign(2, ntohs(((uint16_t*)data)[1]));
rv->Assign(3, static_cast<uint32_t>(ntohl(((uint32_t*)data)[1]))); rv->Assign(3, static_cast<uint32_t>(ntohl(((uint32_t*)data)[1])));
if ( Length() >= 12 ) if ( Length() >= 12 ) {
{
// Sequence Number and ICV fields can only be extracted if // Sequence Number and ICV fields can only be extracted if
// Payload Len was non-zero for this header. // Payload Len was non-zero for this header.
rv->Assign(4, static_cast<uint32_t>(ntohl(((uint32_t*)data)[2]))); rv->Assign(4, static_cast<uint32_t>(ntohl(((uint32_t*)data)[2])));
uint16_t off = 3 * sizeof(uint32_t); uint16_t off = 3 * sizeof(uint32_t);
rv->Assign(5, new String(data + off, Length() - off, true)); rv->Assign(5, new String(data + off, Length() - off, true));
} }
} } break;
break;
case IPPROTO_ESP: case IPPROTO_ESP: {
{
static auto ip6_esp_type = id::find_type<RecordType>("ip6_esp"); static auto ip6_esp_type = id::find_type<RecordType>("ip6_esp");
rv = make_intrusive<RecordVal>(ip6_esp_type); rv = make_intrusive<RecordVal>(ip6_esp_type);
const uint32_t* esp = (const uint32_t*)data; const uint32_t* esp = (const uint32_t*)data;
rv->Assign(0, static_cast<uint32_t>(ntohl(esp[0]))); rv->Assign(0, static_cast<uint32_t>(ntohl(esp[0])));
rv->Assign(1, static_cast<uint32_t>(ntohl(esp[1]))); rv->Assign(1, static_cast<uint32_t>(ntohl(esp[1])));
} } break;
break;
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: {
{
static auto ip6_mob_type = id::find_type<RecordType>("ip6_mobility_hdr"); static auto ip6_mob_type = id::find_type<RecordType>("ip6_mobility_hdr");
rv = make_intrusive<RecordVal>(ip6_mob_type); rv = make_intrusive<RecordVal>(ip6_mob_type);
const struct ip6_mobility* mob = (const struct ip6_mobility*)data; const struct ip6_mobility* mob = (const struct ip6_mobility*)data;
@ -208,10 +183,8 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
static auto ip6_mob_back_type = id::find_type<RecordType>("ip6_mobility_back"); static auto ip6_mob_back_type = id::find_type<RecordType>("ip6_mobility_back");
static auto ip6_mob_be_type = id::find_type<RecordType>("ip6_mobility_be"); static auto ip6_mob_be_type = id::find_type<RecordType>("ip6_mobility_be");
switch ( mob->ip6mob_type ) switch ( mob->ip6mob_type ) {
{ case 0: {
case 0:
{
off += sizeof(uint16_t); off += sizeof(uint16_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -223,8 +196,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 1: case 1: {
{
off += sizeof(uint16_t) + sizeof(uint64_t); off += sizeof(uint16_t) + sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -237,8 +209,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 2: case 2: {
{
off += sizeof(uint16_t) + sizeof(uint64_t); off += sizeof(uint16_t) + sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -251,8 +222,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 3: case 3: {
{
off += sizeof(uint16_t) + 2 * sizeof(uint64_t); off += sizeof(uint16_t) + 2 * sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -260,15 +230,13 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
auto m = make_intrusive<RecordVal>(ip6_mob_hot_type); auto m = make_intrusive<RecordVal>(ip6_mob_hot_type);
m->Assign(0, ntohs(*((uint16_t*)msg_data))); m->Assign(0, ntohs(*((uint16_t*)msg_data)));
m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t)))));
m->Assign( m->Assign(2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
m->Assign(3, BuildOptionsVal(data + off, Length() - off)); m->Assign(3, BuildOptionsVal(data + off, Length() - off));
msg->Assign(4, std::move(m)); msg->Assign(4, std::move(m));
break; break;
} }
case 4: case 4: {
{
off += sizeof(uint16_t) + 2 * sizeof(uint64_t); off += sizeof(uint16_t) + 2 * sizeof(uint64_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -276,45 +244,37 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
auto m = make_intrusive<RecordVal>(ip6_mob_cot_type); auto m = make_intrusive<RecordVal>(ip6_mob_cot_type);
m->Assign(0, ntohs(*((uint16_t*)msg_data))); m->Assign(0, ntohs(*((uint16_t*)msg_data)));
m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t)))));
m->Assign( m->Assign(2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t)))));
m->Assign(3, BuildOptionsVal(data + off, Length() - off)); m->Assign(3, BuildOptionsVal(data + off, Length() - off));
msg->Assign(5, std::move(m)); msg->Assign(5, std::move(m));
break; break;
} }
case 5: case 5: {
{
off += 3 * sizeof(uint16_t); off += 3 * sizeof(uint16_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
auto m = make_intrusive<RecordVal>(ip6_mob_bu_type); auto m = make_intrusive<RecordVal>(ip6_mob_bu_type);
m->Assign(0, ntohs(*((uint16_t*)msg_data))); m->Assign(0, ntohs(*((uint16_t*)msg_data)));
m->Assign(1, static_cast<bool>( m->Assign(1, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x8000));
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x8000)); m->Assign(2, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x4000));
m->Assign(2, static_cast<bool>( m->Assign(3, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x2000));
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x4000)); m->Assign(4, static_cast<bool>(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x1000));
m->Assign(3, static_cast<bool>(
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x2000));
m->Assign(4, static_cast<bool>(
ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x1000));
m->Assign(5, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); m->Assign(5, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t)))));
m->Assign(6, BuildOptionsVal(data + off, Length() - off)); m->Assign(6, BuildOptionsVal(data + off, Length() - off));
msg->Assign(6, std::move(m)); msg->Assign(6, std::move(m));
break; break;
} }
case 6: case 6: {
{
off += 3 * sizeof(uint16_t); off += 3 * sizeof(uint16_t);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
auto m = make_intrusive<RecordVal>(ip6_mob_back_type); auto m = make_intrusive<RecordVal>(ip6_mob_back_type);
m->Assign(0, *((uint8_t*)msg_data)); m->Assign(0, *((uint8_t*)msg_data));
m->Assign(1, m->Assign(1, static_cast<bool>(*((uint8_t*)(msg_data + sizeof(uint8_t))) & 0x80));
static_cast<bool>(*((uint8_t*)(msg_data + sizeof(uint8_t))) & 0x80));
m->Assign(2, ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t))))); m->Assign(2, ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))));
m->Assign(3, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); m->Assign(3, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t)))));
m->Assign(4, BuildOptionsVal(data + off, Length() - off)); m->Assign(4, BuildOptionsVal(data + off, Length() - off));
@ -322,8 +282,7 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
case 7: case 7: {
{
off += sizeof(uint16_t) + sizeof(in6_addr); off += sizeof(uint16_t) + sizeof(in6_addr);
if ( IsOptionTruncated(off) ) if ( IsOptionTruncated(off) )
break; break;
@ -337,53 +296,32 @@ RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const
break; break;
} }
default: default: reporter->Weird("unknown_mobility_type", util::fmt("%d", mob->ip6mob_type)); break;
reporter->Weird("unknown_mobility_type", util::fmt("%d", mob->ip6mob_type));
break;
} }
rv->Assign(5, std::move(msg)); rv->Assign(5, std::move(msg));
} } break;
break;
default: default: break;
break;
} }
return rv; return rv;
} }
RecordValPtr IPv6_Hdr::ToVal() const RecordValPtr IPv6_Hdr::ToVal() const { return ToVal(nullptr); }
{
return ToVal(nullptr);
}
IPAddr IP_Hdr::IPHeaderSrcAddr() const IPAddr IP_Hdr::IPHeaderSrcAddr() const { return ip4 ? IPAddr(ip4->ip_src) : IPAddr(ip6->ip6_src); }
{
return ip4 ? IPAddr(ip4->ip_src) : IPAddr(ip6->ip6_src);
}
IPAddr IP_Hdr::IPHeaderDstAddr() const IPAddr IP_Hdr::IPHeaderDstAddr() const { return ip4 ? IPAddr(ip4->ip_dst) : IPAddr(ip6->ip6_dst); }
{
return ip4 ? IPAddr(ip4->ip_dst) : IPAddr(ip6->ip6_dst);
}
IPAddr IP_Hdr::SrcAddr() const IPAddr IP_Hdr::SrcAddr() const { return ip4 ? IPAddr(ip4->ip_src) : ip6_hdrs->SrcAddr(); }
{
return ip4 ? IPAddr(ip4->ip_src) : ip6_hdrs->SrcAddr();
}
IPAddr IP_Hdr::DstAddr() const IPAddr IP_Hdr::DstAddr() const { return ip4 ? IPAddr(ip4->ip_dst) : ip6_hdrs->DstAddr(); }
{
return ip4 ? IPAddr(ip4->ip_dst) : ip6_hdrs->DstAddr();
}
RecordValPtr IP_Hdr::ToIPHdrVal() const RecordValPtr IP_Hdr::ToIPHdrVal() const {
{
RecordValPtr rval; RecordValPtr rval;
if ( ip4 ) if ( ip4 ) {
{
static auto ip4_hdr_type = id::find_type<RecordType>("ip4_hdr"); static auto ip4_hdr_type = id::find_type<RecordType>("ip4_hdr");
rval = make_intrusive<RecordVal>(ip4_hdr_type); rval = make_intrusive<RecordVal>(ip4_hdr_type);
rval->Assign(0, ip4->ip_hl * 4); rval->Assign(0, ip4->ip_hl * 4);
@ -399,22 +337,19 @@ RecordValPtr IP_Hdr::ToIPHdrVal() const
rval->Assign(10, make_intrusive<AddrVal>(ip4->ip_src.s_addr)); rval->Assign(10, make_intrusive<AddrVal>(ip4->ip_src.s_addr));
rval->Assign(11, make_intrusive<AddrVal>(ip4->ip_dst.s_addr)); rval->Assign(11, make_intrusive<AddrVal>(ip4->ip_dst.s_addr));
} }
else else {
{
rval = ((*ip6_hdrs)[0])->ToVal(ip6_hdrs->ToVal()); rval = ((*ip6_hdrs)[0])->ToVal(ip6_hdrs->ToVal());
} }
return rval; return rval;
} }
RecordValPtr IP_Hdr::ToPktHdrVal() const RecordValPtr IP_Hdr::ToPktHdrVal() const {
{
static auto pkt_hdr_type = id::find_type<RecordType>("pkt_hdr"); static auto pkt_hdr_type = id::find_type<RecordType>("pkt_hdr");
return ToPktHdrVal(make_intrusive<RecordVal>(pkt_hdr_type), 0); return ToPktHdrVal(make_intrusive<RecordVal>(pkt_hdr_type), 0);
} }
RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const {
{
static auto tcp_hdr_type = id::find_type<RecordType>("tcp_hdr"); static auto tcp_hdr_type = id::find_type<RecordType>("tcp_hdr");
static auto udp_hdr_type = id::find_type<RecordType>("udp_hdr"); static auto udp_hdr_type = id::find_type<RecordType>("udp_hdr");
static auto icmp_hdr_type = id::find_type<RecordType>("icmp_hdr"); static auto icmp_hdr_type = id::find_type<RecordType>("icmp_hdr");
@ -428,10 +363,8 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
const u_char* data = Payload(); const u_char* data = Payload();
int proto = NextProto(); int proto = NextProto();
switch ( proto ) switch ( proto ) {
{ case IPPROTO_TCP: {
case IPPROTO_TCP:
{
if ( PayloadLen() < sizeof(struct tcphdr) ) if ( PayloadLen() < sizeof(struct tcphdr) )
break; break;
@ -461,8 +394,7 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
case IPPROTO_UDP: case IPPROTO_UDP: {
{
if ( PayloadLen() < sizeof(struct udphdr) ) if ( PayloadLen() < sizeof(struct udphdr) )
break; break;
@ -477,8 +409,7 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
case IPPROTO_ICMP: case IPPROTO_ICMP: {
{
if ( PayloadLen() < sizeof(struct icmp) ) if ( PayloadLen() < sizeof(struct icmp) )
break; break;
@ -491,8 +422,7 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
case IPPROTO_ICMPV6: case IPPROTO_ICMPV6: {
{
if ( PayloadLen() < sizeof(struct icmp6_hdr) ) if ( PayloadLen() < sizeof(struct icmp6_hdr) )
break; break;
@ -505,8 +435,7 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
break; break;
} }
default: default: {
{
// This is not a protocol we understand. // This is not a protocol we understand.
break; break;
} }
@ -515,47 +444,38 @@ RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const
return pkt_hdr; return pkt_hdr;
} }
static inline bool isIPv6ExtHeader(uint8_t type) static inline bool isIPv6ExtHeader(uint8_t type) {
{ switch ( type ) {
switch ( type )
{
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT:
case IPPROTO_AH: case IPPROTO_AH:
case IPPROTO_ESP: case IPPROTO_ESP:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: return true;
return true; default: return false;
default:
return false;
} }
} }
IPv6_Hdr_Chain::~IPv6_Hdr_Chain() IPv6_Hdr_Chain::~IPv6_Hdr_Chain() {
{
for ( size_t i = 0; i < chain.size(); ++i ) for ( size_t i = 0; i < chain.size(); ++i )
delete chain[i]; delete chain[i];
delete homeAddr; delete homeAddr;
delete finalDst; delete finalDst;
} }
void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, uint16_t next) {
uint16_t next)
{
length = 0; length = 0;
uint8_t current_type, next_type; uint8_t current_type, next_type;
next_type = IPPROTO_IPV6; next_type = IPPROTO_IPV6;
const u_char* hdrs = (const u_char*)ip6; const u_char* hdrs = (const u_char*)ip6;
if ( total_len < (int)sizeof(struct ip6_hdr) ) if ( total_len < (int)sizeof(struct ip6_hdr) ) {
{
reporter->InternalWarning("truncated IP header in IPv6_HdrChain::Init"); reporter->InternalWarning("truncated IP header in IPv6_HdrChain::Init");
return; return;
} }
do do {
{
// We can't determine a given header's length if there's less than // We can't determine a given header's length if there's less than
// two bytes of data available (2nd byte of extension headers is length) // two bytes of data available (2nd byte of extension headers is length)
if ( total_len < 2 ) if ( total_len < 2 )
@ -568,14 +488,12 @@ void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool se
uint16_t cur_len = p->Length(); uint16_t cur_len = p->Length();
// If this header is truncated, don't add it to chain, don't go further. // If this header is truncated, don't add it to chain, don't go further.
if ( cur_len > total_len ) if ( cur_len > total_len ) {
{
delete p; delete p;
return; return;
} }
if ( set_next && next_type == IPPROTO_FRAGMENT ) if ( set_next && next_type == IPPROTO_FRAGMENT ) {
{
p->ChangeNext(next); p->ChangeNext(next);
next_type = next; next_type = next;
} }
@ -594,14 +512,12 @@ void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool se
length += cur_len; length += cur_len;
total_len -= cur_len; total_len -= cur_len;
} while ( current_type != IPPROTO_FRAGMENT && current_type != IPPROTO_ESP && } while ( current_type != IPPROTO_FRAGMENT && current_type != IPPROTO_ESP && current_type != IPPROTO_MOBILITY &&
current_type != IPPROTO_MOBILITY && isIPv6ExtHeader(next_type) ); isIPv6ExtHeader(next_type) );
} }
bool IPv6_Hdr_Chain::IsFragment() const bool IPv6_Hdr_Chain::IsFragment() const {
{ if ( chain.empty() ) {
if ( chain.empty() )
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
return false; return false;
} }
@ -609,13 +525,11 @@ bool IPv6_Hdr_Chain::IsFragment() const
return chain[chain.size() - 1]->Type() == IPPROTO_FRAGMENT; return chain[chain.size() - 1]->Type() == IPPROTO_FRAGMENT;
} }
IPAddr IPv6_Hdr_Chain::SrcAddr() const IPAddr IPv6_Hdr_Chain::SrcAddr() const {
{
if ( homeAddr ) if ( homeAddr )
return {*homeAddr}; return {*homeAddr};
if ( chain.empty() ) if ( chain.empty() ) {
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
return {}; return {};
} }
@ -623,13 +537,11 @@ IPAddr IPv6_Hdr_Chain::SrcAddr() const
return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_src}; return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_src};
} }
IPAddr IPv6_Hdr_Chain::DstAddr() const IPAddr IPv6_Hdr_Chain::DstAddr() const {
{
if ( finalDst ) if ( finalDst )
return {*finalDst}; return {*finalDst};
if ( chain.empty() ) if ( chain.empty() ) {
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
return {}; return {};
} }
@ -637,10 +549,8 @@ IPAddr IPv6_Hdr_Chain::DstAddr() const
return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_dst}; return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_dst};
} }
void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len) void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len) {
{ if ( finalDst ) {
if ( finalDst )
{
// RFC 2460 section 4.1 says Routing should occur at most once. // RFC 2460 section 4.1 says Routing should occur at most once.
reporter->Weird(SrcAddr(), DstAddr(), "multiple_routing_headers"); reporter->Weird(SrcAddr(), DstAddr(), "multiple_routing_headers");
return; return;
@ -649,12 +559,10 @@ void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t le
// Last 16 bytes of header (for all known types) is the address we want. // Last 16 bytes of header (for all known types) is the address we want.
const in6_addr* addr = (const in6_addr*)(((const u_char*)r) + len - 16); const in6_addr* addr = (const in6_addr*)(((const u_char*)r) + len - 16);
switch ( r->ip6r_type ) switch ( r->ip6r_type ) {
{
case 0: // Defined by RFC 2460, deprecated by RFC 5095 case 0: // Defined by RFC 2460, deprecated by RFC 5095
{ {
if ( r->ip6r_segleft > 0 && r->ip6r_len >= 2 ) if ( r->ip6r_segleft > 0 && r->ip6r_len >= 2 ) {
{
if ( r->ip6r_len % 2 == 0 ) if ( r->ip6r_len % 2 == 0 )
finalDst = new IPAddr(*addr); finalDst = new IPAddr(*addr);
else else
@ -663,30 +571,23 @@ void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t le
// Always raise a weird since this type is deprecated. // Always raise a weird since this type is deprecated.
reporter->Weird(SrcAddr(), DstAddr(), "routing0_hdr"); reporter->Weird(SrcAddr(), DstAddr(), "routing0_hdr");
} } break;
break;
case 2: // Defined by Mobile IPv6 RFC 6275. case 2: // Defined by Mobile IPv6 RFC 6275.
{ {
if ( r->ip6r_segleft > 0 ) if ( r->ip6r_segleft > 0 ) {
{
if ( r->ip6r_len == 2 ) if ( r->ip6r_len == 2 )
finalDst = new IPAddr(*addr); finalDst = new IPAddr(*addr);
else else
reporter->Weird(SrcAddr(), DstAddr(), "bad_routing2_len"); reporter->Weird(SrcAddr(), DstAddr(), "bad_routing2_len");
} }
} } break;
break;
default: default: reporter->Weird(SrcAddr(), DstAddr(), "unknown_routing_type", util::fmt("%d", r->ip6r_type)); break;
reporter->Weird(SrcAddr(), DstAddr(), "unknown_routing_type",
util::fmt("%d", r->ip6r_type));
break;
} }
} }
void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len) void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len) {
{
// Skip two bytes to get the beginning of the first option structure. These // Skip two bytes to get the beginning of the first option structure. These
// two bytes are the protocol for the next header and extension header length, // two bytes are the protocol for the next header and extension header length,
// already known to exist before calling this method. See header format: // already known to exist before calling this method. See header format:
@ -697,39 +598,32 @@ void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len)
len -= 2 * sizeof(uint8_t); len -= 2 * sizeof(uint8_t);
data += 2 * sizeof(uint8_t); data += 2 * sizeof(uint8_t);
while ( len > 0 ) while ( len > 0 ) {
{
const struct ip6_opt* opt = (const struct ip6_opt*)data; const struct ip6_opt* opt = (const struct ip6_opt*)data;
switch ( opt->ip6o_type ) switch ( opt->ip6o_type ) {
{
case 0: case 0:
// If option type is zero, it's a Pad0 and can be just a single // If option type is zero, it's a Pad0 and can be just a single
// byte in width. Skip over it. // byte in width. Skip over it.
data += sizeof(uint8_t); data += sizeof(uint8_t);
len -= sizeof(uint8_t); len -= sizeof(uint8_t);
break; break;
default: default: {
{
// Double-check that the len can hold the whole option structure. // Double-check that the len can hold the whole option structure.
// Otherwise we get a buffer-overflow when we check the option_len. // Otherwise we get a buffer-overflow when we check the option_len.
// Also check that it holds everything for the option itself. // Also check that it holds everything for the option itself.
if ( len < sizeof(struct ip6_opt) || len < sizeof(struct ip6_opt) + opt->ip6o_len ) if ( len < sizeof(struct ip6_opt) || len < sizeof(struct ip6_opt) + opt->ip6o_len ) {
{
reporter->Weird(SrcAddr(), DstAddr(), "bad_ipv6_dest_opt_len"); reporter->Weird(SrcAddr(), DstAddr(), "bad_ipv6_dest_opt_len");
len = 0; len = 0;
break; break;
} }
if ( opt->ip6o_type == if ( opt->ip6o_type == 201 ) // Home Address Option, Mobile IPv6 RFC 6275 section 6.3
201 ) // Home Address Option, Mobile IPv6 RFC 6275 section 6.3
{
if ( opt->ip6o_len == sizeof(struct in6_addr) )
{ {
if ( opt->ip6o_len == sizeof(struct in6_addr) ) {
if ( homeAddr ) if ( homeAddr )
reporter->Weird(SrcAddr(), DstAddr(), "multiple_home_addr_opts"); reporter->Weird(SrcAddr(), DstAddr(), "multiple_home_addr_opts");
else else
homeAddr = new IPAddr( homeAddr = new IPAddr(*((const in6_addr*)(data + sizeof(struct ip6_opt))));
*((const in6_addr*)(data + sizeof(struct ip6_opt))));
} }
else else
reporter->Weird(SrcAddr(), DstAddr(), "bad_home_addr_len"); reporter->Weird(SrcAddr(), DstAddr(), "bad_home_addr_len");
@ -737,14 +631,12 @@ void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len)
data += sizeof(struct ip6_opt) + opt->ip6o_len; data += sizeof(struct ip6_opt) + opt->ip6o_len;
len -= sizeof(struct ip6_opt) + opt->ip6o_len; len -= sizeof(struct ip6_opt) + opt->ip6o_len;
} } break;
break;
} }
} }
} }
VectorValPtr IPv6_Hdr_Chain::ToVal() const VectorValPtr IPv6_Hdr_Chain::ToVal() const {
{
static auto ip6_ext_hdr_type = id::find_type<RecordType>("ip6_ext_hdr"); static auto ip6_ext_hdr_type = id::find_type<RecordType>("ip6_ext_hdr");
static auto ip6_hopopts_type = id::find_type<RecordType>("ip6_hopopts"); static auto ip6_hopopts_type = id::find_type<RecordType>("ip6_hopopts");
static auto ip6_dstopts_type = id::find_type<RecordType>("ip6_dstopts"); static auto ip6_dstopts_type = id::find_type<RecordType>("ip6_dstopts");
@ -755,39 +647,21 @@ VectorValPtr IPv6_Hdr_Chain::ToVal() const
static auto ip6_ext_hdr_chain_type = id::find_type<VectorType>("ip6_ext_hdr_chain"); static auto ip6_ext_hdr_chain_type = id::find_type<VectorType>("ip6_ext_hdr_chain");
auto rval = make_intrusive<VectorVal>(ip6_ext_hdr_chain_type); auto rval = make_intrusive<VectorVal>(ip6_ext_hdr_chain_type);
for ( size_t i = 1; i < chain.size(); ++i ) for ( size_t i = 1; i < chain.size(); ++i ) {
{
auto v = chain[i]->ToVal(); auto v = chain[i]->ToVal();
auto ext_hdr = make_intrusive<RecordVal>(ip6_ext_hdr_type); auto ext_hdr = make_intrusive<RecordVal>(ip6_ext_hdr_type);
uint8_t type = chain[i]->Type(); uint8_t type = chain[i]->Type();
ext_hdr->Assign(0, type); ext_hdr->Assign(0, type);
switch ( type ) switch ( type ) {
{ case IPPROTO_HOPOPTS: ext_hdr->Assign(1, std::move(v)); break;
case IPPROTO_HOPOPTS: case IPPROTO_DSTOPTS: ext_hdr->Assign(2, std::move(v)); break;
ext_hdr->Assign(1, std::move(v)); case IPPROTO_ROUTING: ext_hdr->Assign(3, std::move(v)); break;
break; case IPPROTO_FRAGMENT: ext_hdr->Assign(4, std::move(v)); break;
case IPPROTO_DSTOPTS: case IPPROTO_AH: ext_hdr->Assign(5, std::move(v)); break;
ext_hdr->Assign(2, std::move(v)); case IPPROTO_ESP: ext_hdr->Assign(6, std::move(v)); break;
break; case IPPROTO_MOBILITY: ext_hdr->Assign(7, std::move(v)); break;
case IPPROTO_ROUTING: default: reporter->InternalWarning("IPv6_Hdr_Chain bad header %d", type); continue;
ext_hdr->Assign(3, std::move(v));
break;
case IPPROTO_FRAGMENT:
ext_hdr->Assign(4, std::move(v));
break;
case IPPROTO_AH:
ext_hdr->Assign(5, std::move(v));
break;
case IPPROTO_ESP:
ext_hdr->Assign(6, std::move(v));
break;
case IPPROTO_MOBILITY:
ext_hdr->Assign(7, std::move(v));
break;
default:
reporter->InternalWarning("IPv6_Hdr_Chain bad header %d", type);
continue;
} }
rval->Assign(rval->Size(), std::move(ext_hdr)); rval->Assign(rval->Size(), std::move(ext_hdr));
@ -796,12 +670,10 @@ VectorValPtr IPv6_Hdr_Chain::ToVal() const
return rval; return rval;
} }
IP_Hdr* IP_Hdr::Copy() const IP_Hdr* IP_Hdr::Copy() const {
{
char* new_hdr = new char[HdrLen()]; char* new_hdr = new char[HdrLen()];
if ( ip4 ) if ( ip4 ) {
{
memcpy(new_hdr, ip4, HdrLen()); memcpy(new_hdr, ip4, HdrLen());
return new IP_Hdr((const struct ip*)new_hdr, true); return new IP_Hdr((const struct ip*)new_hdr, true);
} }
@ -812,8 +684,7 @@ IP_Hdr* IP_Hdr::Copy() const
return new IP_Hdr(new_ip6, true, 0, new_ip6_hdrs); return new IP_Hdr(new_ip6, true, 0, new_ip6_hdrs);
} }
IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const {
{
IPv6_Hdr_Chain* rval = new IPv6_Hdr_Chain; IPv6_Hdr_Chain* rval = new IPv6_Hdr_Chain;
rval->length = length; rval->length = length;
@ -823,8 +694,7 @@ IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const
if ( finalDst ) if ( finalDst )
rval->finalDst = new IPAddr(*finalDst); rval->finalDst = new IPAddr(*finalDst);
if ( chain.empty() ) if ( chain.empty() ) {
{
reporter->InternalWarning("empty IPv6 header chain"); reporter->InternalWarning("empty IPv6 header chain");
delete rval; delete rval;
return nullptr; return nullptr;
@ -833,8 +703,7 @@ IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const
const u_char* new_data = (const u_char*)new_hdr; const u_char* new_data = (const u_char*)new_hdr;
const u_char* old_data = chain[0]->Data(); const u_char* old_data = chain[0]->Data();
for ( size_t i = 0; i < chain.size(); ++i ) for ( size_t i = 0; i < chain.size(); ++i ) {
{
int off = chain[i]->Data() - old_data; int off = chain[i]->Data() - old_data;
rval->chain.push_back(new IPv6_Hdr(chain[i]->Type(), new_data + off)); rval->chain.push_back(new IPv6_Hdr(chain[i]->Type(), new_data + off));
} }

137
src/IP.h
View file

@ -20,8 +20,7 @@
#include "zeek/IntrusivePtr.h" #include "zeek/IntrusivePtr.h"
namespace zeek namespace zeek {
{
class IPAddr; class IPAddr;
class RecordVal; class RecordVal;
@ -29,8 +28,7 @@ class VectorVal;
using RecordValPtr = IntrusivePtr<RecordVal>; using RecordValPtr = IntrusivePtr<RecordVal>;
using VectorValPtr = IntrusivePtr<VectorVal>; using VectorValPtr = IntrusivePtr<VectorVal>;
namespace detail namespace detail {
{
class FragReassembler; class FragReassembler;
} }
@ -38,8 +36,7 @@ class FragReassembler;
#define IPPROTO_MOBILITY 135 #define IPPROTO_MOBILITY 135
#endif #endif
struct ip6_mobility struct ip6_mobility {
{
uint8_t ip6mob_payload; uint8_t ip6mob_payload;
uint8_t ip6mob_len; uint8_t ip6mob_len;
uint8_t ip6mob_type; uint8_t ip6mob_type;
@ -50,8 +47,7 @@ struct ip6_mobility
/** /**
* Base class for IPv6 header/extensions. * Base class for IPv6 header/extensions.
*/ */
class IPv6_Hdr class IPv6_Hdr {
{
public: public:
/** /**
* Construct an IPv6 header or extension header from assigned type number. * Construct an IPv6 header or extension header from assigned type number.
@ -61,24 +57,17 @@ public:
/** /**
* Replace the value of the next protocol field. * Replace the value of the next protocol field.
*/ */
void ChangeNext(uint8_t next_type) void ChangeNext(uint8_t next_type) {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: ((ip6_hdr*)data)->ip6_nxt = next_type; break;
{
case IPPROTO_IPV6:
((ip6_hdr*)data)->ip6_nxt = next_type;
break;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT:
case IPPROTO_AH: case IPPROTO_AH:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: ((ip6_ext*)data)->ip6e_nxt = next_type; break;
((ip6_ext*)data)->ip6e_nxt = next_type;
break;
case IPPROTO_ESP: case IPPROTO_ESP:
default: default: break;
break;
} }
} }
@ -88,47 +77,34 @@ public:
* Returns the assigned IPv6 extension header type number of the header * Returns the assigned IPv6 extension header type number of the header
* that immediately follows this one. * that immediately follows this one.
*/ */
uint8_t NextHdr() const uint8_t NextHdr() const {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: return ((ip6_hdr*)data)->ip6_nxt;
{
case IPPROTO_IPV6:
return ((ip6_hdr*)data)->ip6_nxt;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_FRAGMENT: case IPPROTO_FRAGMENT:
case IPPROTO_AH: case IPPROTO_AH:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: return ((ip6_ext*)data)->ip6e_nxt;
return ((ip6_ext*)data)->ip6e_nxt;
case IPPROTO_ESP: case IPPROTO_ESP:
default: default: return IPPROTO_NONE;
return IPPROTO_NONE;
} }
} }
/** /**
* Returns the length of the header in bytes. * Returns the length of the header in bytes.
*/ */
uint16_t Length() const uint16_t Length() const {
{ switch ( type ) {
switch ( type ) case IPPROTO_IPV6: return 40;
{
case IPPROTO_IPV6:
return 40;
case IPPROTO_HOPOPTS: case IPPROTO_HOPOPTS:
case IPPROTO_DSTOPTS: case IPPROTO_DSTOPTS:
case IPPROTO_ROUTING: case IPPROTO_ROUTING:
case IPPROTO_MOBILITY: case IPPROTO_MOBILITY: return 8 + 8 * ((ip6_ext*)data)->ip6e_len;
return 8 + 8 * ((ip6_ext*)data)->ip6e_len; case IPPROTO_FRAGMENT: return 8;
case IPPROTO_FRAGMENT: case IPPROTO_AH: return 8 + 4 * ((ip6_ext*)data)->ip6e_len;
return 8; case IPPROTO_ESP: return 8; // encrypted payload begins after 8 bytes
case IPPROTO_AH: default: return 0;
return 8 + 4 * ((ip6_ext*)data)->ip6e_len;
case IPPROTO_ESP:
return 8; // encrypted payload begins after 8 bytes
default:
return 0;
} }
} }
@ -156,8 +132,7 @@ private:
bool IsOptionTruncated(uint16_t off) const; bool IsOptionTruncated(uint16_t off) const;
}; };
class IPv6_Hdr_Chain class IPv6_Hdr_Chain {
{
public: public:
/** /**
* Initializes the header chain from an IPv6 header structure. * Initializes the header chain from an IPv6 header structure.
@ -195,8 +170,7 @@ public:
/** /**
* Returns pointer to fragment header structure if the chain contains one. * Returns pointer to fragment header structure if the chain contains one.
*/ */
const struct ip6_frag* GetFragHdr() const const struct ip6_frag* GetFragHdr() const {
{
return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr; return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr;
} }
@ -204,10 +178,7 @@ public:
* If the header chain is a fragment, returns the offset in number of bytes * If the header chain is a fragment, returns the offset in number of bytes
* relative to the start of the Fragmentable Part of the original packet. * relative to the start of the Fragmentable Part of the original packet.
*/ */
uint16_t FragOffset() const uint16_t FragOffset() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0; }
{
return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0;
}
/** /**
* If the header chain is a fragment, returns the identification field. * If the header chain is a fragment, returns the identification field.
@ -250,10 +221,7 @@ protected:
* Initializes the header chain from an IPv6 header structure, and replaces * Initializes the header chain from an IPv6 header structure, and replaces
* the first next protocol pointer field that points to a fragment header. * the first next protocol pointer field that points to a fragment header.
*/ */
IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) { Init(ip6, len, true, next); }
{
Init(ip6, len, true, next);
}
/** /**
* Initializes the header chain from an IPv6 header structure of a given * Initializes the header chain from an IPv6 header structure of a given
@ -297,8 +265,7 @@ protected:
* A class that wraps either an IPv4 or IPv6 packet and abstracts methods * A class that wraps either an IPv4 or IPv6 packet and abstracts methods
* for inquiring about common features between the two. * for inquiring about common features between the two.
*/ */
class IP_Hdr class IP_Hdr {
{
public: public:
/** /**
* Construct the header wrapper from an IPv4 packet. Caller must have * Construct the header wrapper from an IPv4 packet. Caller must have
@ -308,9 +275,7 @@ public:
* @param reassembled whether this header is for a reassembled packet. * @param reassembled whether this header is for a reassembled packet.
*/ */
IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false) IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false)
: ip4(arg_ip4), del(arg_del), reassembled(reassembled) : ip4(arg_ip4), del(arg_del), reassembled(reassembled) {}
{
}
/** /**
* Construct the header wrapper from an IPv6 packet. Caller must have * Construct the header wrapper from an IPv6 packet. Caller must have
@ -324,12 +289,9 @@ public:
* @param c an already-constructed header chain to take ownership of. * @param c an already-constructed header chain to take ownership of.
* @param reassembled whether this header is for a reassembled packet. * @param reassembled whether this header is for a reassembled packet.
*/ */
IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, const IPv6_Hdr_Chain* c = nullptr,
const IPv6_Hdr_Chain* c = nullptr, bool reassembled = false) bool reassembled = false)
: ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), : ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), reassembled(reassembled) {}
reassembled(reassembled)
{
}
/** /**
* Copy a header. The internal buffer which contains the header data * Copy a header. The internal buffer which contains the header data
@ -341,12 +303,10 @@ public:
/** /**
* Destructor. * Destructor.
*/ */
~IP_Hdr() ~IP_Hdr() {
{
delete ip6_hdrs; delete ip6_hdrs;
if ( del ) if ( del ) {
{
delete[] (struct ip*)ip4; delete[] (struct ip*)ip4;
delete[] (struct ip6_hdr*)ip6; delete[] (struct ip6_hdr*)ip6;
} }
@ -391,8 +351,7 @@ public:
* Returns a pointer to the payload of the IP packet, usually an * Returns a pointer to the payload of the IP packet, usually an
* upper-layer protocol. * upper-layer protocol.
*/ */
const u_char* Payload() const const u_char* Payload() const {
{
if ( ip4 ) if ( ip4 )
return ((const u_char*)ip4) + ip4->ip_hl * 4; return ((const u_char*)ip4) + ip4->ip_hl * 4;
@ -403,8 +362,7 @@ public:
* Returns a pointer to the mobility header of the IP packet, if present, * Returns a pointer to the mobility header of the IP packet, if present,
* else a null pointer. * else a null pointer.
*/ */
const ip6_mobility* MobilityHeader() const const ip6_mobility* MobilityHeader() const {
{
if ( ip4 ) if ( ip4 )
return nullptr; return nullptr;
else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY ) else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY )
@ -420,10 +378,8 @@ public:
* Also returns 0 if the IPv4 length field is set to zero - which is, e.g., * Also returns 0 if the IPv4 length field is set to zero - which is, e.g.,
* the case when TCP segment offloading is enabled. * the case when TCP segment offloading is enabled.
*/ */
uint16_t PayloadLen() const uint16_t PayloadLen() const {
{ if ( ip4 ) {
if ( ip4 )
{
// prevent overflow in case of segment offloading/zeroed header length. // prevent overflow in case of segment offloading/zeroed header length.
auto total_len = ntohs(ip4->ip_len); auto total_len = ntohs(ip4->ip_len);
return total_len ? total_len - ip4->ip_hl * 4 : 0; return total_len ? total_len - ip4->ip_hl * 4 : 0;
@ -435,8 +391,7 @@ public:
/** /**
* Returns the length of the IP packet (length of headers and payload). * Returns the length of the IP packet (length of headers and payload).
*/ */
uint32_t TotalLen() const uint32_t TotalLen() const {
{
if ( ip4 ) if ( ip4 )
return ntohs(ip4->ip_len); return ntohs(ip4->ip_len);
@ -451,8 +406,7 @@ public:
/** /**
* For IPv6 header chains, returns the type of the last header in the chain. * For IPv6 header chains, returns the type of the last header in the chain.
*/ */
uint8_t LastHeader() const uint8_t LastHeader() const {
{
if ( ip4 ) if ( ip4 )
return IPPROTO_RAW; return IPPROTO_RAW;
@ -468,8 +422,7 @@ public:
* upper-layer protocol. For IPv6, this returns the last (extension) * upper-layer protocol. For IPv6, this returns the last (extension)
* header's Next Header value. * header's Next Header value.
*/ */
unsigned char NextProto() const unsigned char NextProto() const {
{
if ( ip4 ) if ( ip4 )
return ip4->ip_p; return ip4->ip_p;
@ -488,19 +441,13 @@ public:
/** /**
* Returns whether the IP header indicates this packet is a fragment. * Returns whether the IP header indicates this packet is a fragment.
*/ */
bool IsFragment() const bool IsFragment() const { return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment(); }
{
return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment();
}
/** /**
* Returns the fragment packet's offset in relation to the original * Returns the fragment packet's offset in relation to the original
* packet in bytes. * packet in bytes.
*/ */
uint16_t FragOffset() const uint16_t FragOffset() const { return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset(); }
{
return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset();
}
/** /**
* Returns the fragment packet's identification field. * Returns the fragment packet's identification field.

View file

@ -13,28 +13,23 @@
#include "zeek/ZeekString.h" #include "zeek/ZeekString.h"
#include "zeek/analyzer/Manager.h" #include "zeek/analyzer/Manager.h"
namespace zeek namespace zeek {
{
const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{}); const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{});
const IPAddr IPAddr::v6_unspecified = IPAddr(); const IPAddr IPAddr::v6_unspecified = IPAddr();
namespace detail namespace detail {
{
ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
TransportProto t, bool one_way) bool one_way) {
{
Init(src, dst, src_port, dst_port, t, one_way); Init(src, dst, src_port, dst_port, t, one_way);
} }
ConnKey::ConnKey(const ConnTuple& id) ConnKey::ConnKey(const ConnTuple& id) {
{
Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way); Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way);
} }
ConnKey& ConnKey::operator=(const ConnKey& rhs) ConnKey& ConnKey::operator=(const ConnKey& rhs) {
{
if ( this == &rhs ) if ( this == &rhs )
return *this; return *this;
@ -53,11 +48,9 @@ ConnKey& ConnKey::operator=(const ConnKey& rhs)
return *this; return *this;
} }
ConnKey::ConnKey(Val* v) ConnKey::ConnKey(Val* v) {
{
const auto& vt = v->GetType(); const auto& vt = v->GetType();
if ( ! IsRecord(vt->Tag()) ) if ( ! IsRecord(vt->Tag()) ) {
{
valid = false; valid = false;
return; return;
} }
@ -68,23 +61,20 @@ ConnKey::ConnKey(Val* v)
int orig_h, orig_p; // indices into record's value list int orig_h, orig_p; // indices into record's value list
int resp_h, resp_p; int resp_h, resp_p;
if ( vr == id::conn_id ) if ( vr == id::conn_id ) {
{
orig_h = 0; orig_h = 0;
orig_p = 1; orig_p = 1;
resp_h = 2; resp_h = 2;
resp_p = 3; resp_p = 3;
} }
else else {
{
// While it's not a conn_id, it may have equivalent fields. // While it's not a conn_id, it may have equivalent fields.
orig_h = vr->FieldOffset("orig_h"); orig_h = vr->FieldOffset("orig_h");
resp_h = vr->FieldOffset("resp_h"); resp_h = vr->FieldOffset("resp_h");
orig_p = vr->FieldOffset("orig_p"); orig_p = vr->FieldOffset("orig_p");
resp_p = vr->FieldOffset("resp_p"); resp_p = vr->FieldOffset("resp_p");
if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 ) if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 ) {
{
valid = false; valid = false;
return; return;
} }
@ -99,13 +89,12 @@ ConnKey::ConnKey(Val* v)
auto orig_portv = vl->GetFieldAs<PortVal>(orig_p); auto orig_portv = vl->GetFieldAs<PortVal>(orig_p);
auto resp_portv = vl->GetFieldAs<PortVal>(resp_p); auto resp_portv = vl->GetFieldAs<PortVal>(resp_p);
Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), htons((unsigned short)resp_portv->Port()),
htons((unsigned short)resp_portv->Port()), orig_portv->PortType(), false); orig_portv->PortType(), false);
} }
void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
TransportProto t, bool one_way) bool one_way) {
{
// Because of padding in the object, this needs to memset to clear out // Because of padding in the object, this needs to memset to clear out
// the extra memory used by padding. Otherwise, the session key stuff // the extra memory used by padding. Otherwise, the session key stuff
// doesn't work quite right. // doesn't work quite right.
@ -114,15 +103,13 @@ void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint
// Lookup up connection based on canonical ordering, which is // Lookup up connection based on canonical ordering, which is
// the smaller of <src addr, src port> and <dst addr, dst port> // the smaller of <src addr, src port> and <dst addr, dst port>
// followed by the other. // followed by the other.
if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) {
{
ip1 = src.in6; ip1 = src.in6;
ip2 = dst.in6; ip2 = dst.in6;
port1 = src_port; port1 = src_port;
port2 = dst_port; port2 = dst_port;
} }
else else {
{
ip1 = dst.in6; ip1 = dst.in6;
ip2 = src.in6; ip2 = src.in6;
port1 = dst_port; port1 = dst_port;
@ -135,28 +122,21 @@ void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint
} // namespace detail } // namespace detail
IPAddr::IPAddr(const String& s) IPAddr::IPAddr(const String& s) { Init(s.CheckString()); }
{
Init(s.CheckString());
}
std::unique_ptr<detail::HashKey> IPAddr::MakeHashKey() const std::unique_ptr<detail::HashKey> IPAddr::MakeHashKey() const {
{
return std::make_unique<detail::HashKey>((void*)in6.s6_addr, sizeof(in6.s6_addr)); return std::make_unique<detail::HashKey>((void*)in6.s6_addr, sizeof(in6.s6_addr));
} }
static inline uint32_t bit_mask32(int bottom_bits) static inline uint32_t bit_mask32(int bottom_bits) {
{
if ( bottom_bits >= 32 ) if ( bottom_bits >= 32 )
return 0xffffffff; return 0xffffffff;
return (((uint32_t)1) << bottom_bits) - 1; return (((uint32_t)1) << bottom_bits) - 1;
} }
void IPAddr::Mask(int top_bits_to_keep) void IPAddr::Mask(int top_bits_to_keep) {
{ if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 ) {
if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 )
{
reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep); reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep);
return; return;
} }
@ -176,10 +156,8 @@ void IPAddr::Mask(int top_bits_to_keep)
p[i] &= mask_bits[i]; p[i] &= mask_bits[i];
} }
void IPAddr::ReverseMask(int top_bits_to_chop) void IPAddr::ReverseMask(int top_bits_to_chop) {
{ if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 ) {
if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 )
{
reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop); reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop);
return; return;
} }
@ -199,8 +177,7 @@ void IPAddr::ReverseMask(int top_bits_to_chop)
p[i] &= mask_bits[i]; p[i] &= mask_bits[i];
} }
bool IPAddr::ConvertString(const char* s, in6_addr* result) bool IPAddr::ConvertString(const char* s, in6_addr* result) {
{
for ( auto p = s; *p; ++p ) for ( auto p = s; *p; ++p )
if ( *p == ':' ) if ( *p == ':' )
// IPv6 // IPv6
@ -231,19 +208,15 @@ bool IPAddr::ConvertString(const char* s, in6_addr* result)
return true; return true;
} }
void IPAddr::Init(const char* s) void IPAddr::Init(const char* s) {
{ if ( ! ConvertString(s, &in6) ) {
if ( ! ConvertString(s, &in6) )
{
reporter->Error("Bad IP address: %s", s); reporter->Error("Bad IP address: %s", s);
memset(in6.s6_addr, 0, sizeof(in6.s6_addr)); memset(in6.s6_addr, 0, sizeof(in6.s6_addr));
} }
} }
std::string IPAddr::AsString() const std::string IPAddr::AsString() const {
{ if ( GetFamily() == IPv4 ) {
if ( GetFamily() == IPv4 )
{
char s[INET_ADDRSTRLEN]; char s[INET_ADDRSTRLEN];
if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) ) if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) )
@ -251,8 +224,7 @@ std::string IPAddr::AsString() const
else else
return s; return s;
} }
else else {
{
char s[INET6_ADDRSTRLEN]; char s[INET6_ADDRSTRLEN];
if ( ! zeek_inet_ntop(AF_INET6, in6.s6_addr, s, INET6_ADDRSTRLEN) ) if ( ! zeek_inet_ntop(AF_INET6, in6.s6_addr, s, INET6_ADDRSTRLEN) )
@ -262,17 +234,14 @@ std::string IPAddr::AsString() const
} }
} }
std::string IPAddr::AsHexString() const std::string IPAddr::AsHexString() const {
{
char buf[33]; char buf[33];
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 ) {
{
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; uint32_t* p = (uint32_t*)&in6.s6_addr[12];
snprintf(buf, sizeof(buf), "%08x", (uint32_t)ntohl(*p)); snprintf(buf, sizeof(buf), "%08x", (uint32_t)ntohl(*p));
} }
else else {
{
uint32_t* p = (uint32_t*)in6.s6_addr; uint32_t* p = (uint32_t*)in6.s6_addr;
snprintf(buf, sizeof(buf), "%08x%08x%08x%08x", (uint32_t)ntohl(p[0]), (uint32_t)ntohl(p[1]), snprintf(buf, sizeof(buf), "%08x%08x%08x%08x", (uint32_t)ntohl(p[0]), (uint32_t)ntohl(p[1]),
(uint32_t)ntohl(p[2]), (uint32_t)ntohl(p[3])); (uint32_t)ntohl(p[2]), (uint32_t)ntohl(p[3]));
@ -281,10 +250,8 @@ std::string IPAddr::AsHexString() const
return buf; return buf;
} }
std::string IPAddr::PtrName() const std::string IPAddr::PtrName() const {
{ if ( GetFamily() == IPv4 ) {
if ( GetFamily() == IPv4 )
{
char buf[256]; char buf[256];
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; uint32_t* p = (uint32_t*)&in6.s6_addr[12];
uint32_t a = ntohl(*p); uint32_t a = ntohl(*p);
@ -295,17 +262,14 @@ std::string IPAddr::PtrName() const
snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3); snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3);
return buf; return buf;
} }
else else {
{
static const char hex_digit[] = "0123456789abcdef"; static const char hex_digit[] = "0123456789abcdef";
std::string ptr_name("ip6.arpa"); std::string ptr_name("ip6.arpa");
uint32_t* p = (uint32_t*)in6.s6_addr; uint32_t* p = (uint32_t*)in6.s6_addr;
for ( unsigned int i = 0; i < 4; ++i ) for ( unsigned int i = 0; i < 4; ++i ) {
{
uint32_t a = ntohl(p[i]); uint32_t a = ntohl(p[i]);
for ( unsigned int j = 1; j <= 8; ++j ) for ( unsigned int j = 1; j <= 8; ++j ) {
{
ptr_name.insert(0, 1, '.'); ptr_name.insert(0, 1, '.');
ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]); ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]);
} }
@ -315,10 +279,8 @@ std::string IPAddr::PtrName() const
} }
} }
IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) {
{ if ( length > 32 ) {
if ( length > 32 )
{
reporter->Error("Bad in4_addr IPPrefix length : %d", length); reporter->Error("Bad in4_addr IPPrefix length : %d", length);
this->length = 0; this->length = 0;
} }
@ -326,10 +288,8 @@ IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96
prefix.Mask(this->length); prefix.Mask(this->length);
} }
IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) {
{ if ( length > 128 ) {
if ( length > 128 )
{
reporter->Error("Bad in6_addr IPPrefix length : %d", length); reporter->Error("Bad in6_addr IPPrefix length : %d", length);
this->length = 0; this->length = 0;
} }
@ -337,16 +297,13 @@ IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(le
prefix.Mask(this->length); prefix.Mask(this->length);
} }
bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const {
{ if ( GetFamily() == IPv4 && ! len_is_v6_relative ) {
if ( GetFamily() == IPv4 && ! len_is_v6_relative )
{
if ( length > 32 ) if ( length > 32 )
return false; return false;
} }
else else {
{
if ( length > 128 ) if ( length > 128 )
return false; return false;
} }
@ -354,17 +311,14 @@ bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const
return true; return true;
} }
IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr) IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr) {
{ if ( prefix.CheckPrefixLength(length, len_is_v6_relative) ) {
if ( prefix.CheckPrefixLength(length, len_is_v6_relative) )
{
if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative ) if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative )
this->length = length + 96; this->length = length + 96;
else else
this->length = length; this->length = length;
} }
else else {
{
auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6"; auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6";
reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length); reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length);
this->length = 0; this->length = 0;
@ -373,8 +327,7 @@ IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative)
prefix.Mask(this->length); prefix.Mask(this->length);
} }
std::string IPPrefix::AsString() const std::string IPPrefix::AsString() const {
{
char l[16]; char l[16];
if ( prefix.GetFamily() == IPv4 ) if ( prefix.GetFamily() == IPv4 )
@ -385,10 +338,8 @@ std::string IPPrefix::AsString() const
return prefix.AsString() + "/" + l; return prefix.AsString() + "/" + l;
} }
std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const {
{ struct {
struct
{
in6_addr ip; in6_addr ip;
uint32_t len; uint32_t len;
} key; } key;
@ -399,8 +350,7 @@ std::unique_ptr<detail::HashKey> IPPrefix::MakeHashKey() const
return std::make_unique<detail::HashKey>(&key, sizeof(key)); return std::make_unique<detail::HashKey>(&key, sizeof(key));
} }
bool IPPrefix::ConvertString(const char* text, IPPrefix* result) bool IPPrefix::ConvertString(const char* text, IPPrefix* result) {
{
std::string s(text); std::string s(text);
size_t slash_loc = s.find('/'); size_t slash_loc = s.find('/');

View file

@ -12,20 +12,17 @@
using in4_addr = in_addr; using in4_addr = in_addr;
namespace zeek namespace zeek {
{
class String; class String;
struct ConnTuple; struct ConnTuple;
class Val; class Val;
namespace detail namespace detail {
{
class HashKey; class HashKey;
class ConnKey class ConnKey {
{
public: public:
in6_addr ip1; in6_addr ip1;
in6_addr ip2; in6_addr ip2;
@ -34,8 +31,7 @@ public:
TransportProto transport = TRANSPORT_UNKNOWN; TransportProto transport = TRANSPORT_UNKNOWN;
bool valid = true; bool valid = true;
ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t, bool one_way);
TransportProto t, bool one_way);
ConnKey(const ConnTuple& conn); ConnKey(const ConnTuple& conn);
ConnKey(const ConnKey& rhs) { *this = rhs; } ConnKey(const ConnKey& rhs) { *this = rhs; }
ConnKey(Val* v); ConnKey(Val* v);
@ -50,8 +46,8 @@ public:
ConnKey& operator=(const ConnKey& rhs); ConnKey& operator=(const ConnKey& rhs);
private: private:
void Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, void Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t,
TransportProto t, bool one_way); bool one_way);
}; };
} // namespace detail } // namespace detail
@ -59,8 +55,7 @@ private:
/** /**
* Class storing both IPv4 and IPv6 addresses. * Class storing both IPv4 and IPv6 addresses.
*/ */
class IPAddr class IPAddr {
{
public: public:
/** /**
* Address family. * Address family.
@ -70,11 +65,7 @@ public:
/** /**
* Byte order. * Byte order.
*/ */
enum ByteOrder enum ByteOrder { Host, Network };
{
Host,
Network
};
/** /**
* Constructs the unspecified IPv6 address (all 128 bits zeroed). * Constructs the unspecified IPv6 address (all 128 bits zeroed).
@ -86,8 +77,7 @@ public:
* *
* @param in6 The IPv6 address. * @param in6 The IPv6 address.
*/ */
explicit IPAddr(const in4_addr& in4) explicit IPAddr(const in4_addr& in4) {
{
memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix));
memcpy(&in6.s6_addr[12], &in4.s_addr, sizeof(in4.s_addr)); memcpy(&in6.s6_addr[12], &in4.s_addr, sizeof(in4.s_addr));
} }
@ -150,8 +140,7 @@ public:
/** /**
* Returns the address' family. * Returns the address' family.
*/ */
Family GetFamily() const Family GetFamily() const {
{
if ( memcmp(in6.s6_addr, v4_mapped_prefix, 12) == 0 ) if ( memcmp(in6.s6_addr, v4_mapped_prefix, 12) == 0 )
return IPv4; return IPv4;
@ -166,8 +155,7 @@ public:
/** /**
* Returns true if the address represents a multicast address. * Returns true if the address represents a multicast address.
*/ */
bool IsMulticast() const bool IsMulticast() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return in6.s6_addr[12] == 224; return in6.s6_addr[12] == 224;
@ -177,11 +165,10 @@ public:
/** /**
* Returns true if the address represents a broadcast address. * Returns true if the address represents a broadcast address.
*/ */
bool IsBroadcast() const bool IsBroadcast() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return ((in6.s6_addr[12] == 0xff) && (in6.s6_addr[13] == 0xff) && return ((in6.s6_addr[12] == 0xff) && (in6.s6_addr[13] == 0xff) && (in6.s6_addr[14] == 0xff) &&
(in6.s6_addr[14] == 0xff) && (in6.s6_addr[15] == 0xff)); (in6.s6_addr[15] == 0xff));
return false; return false;
} }
@ -198,15 +185,12 @@ public:
* @return The number of 32-bit words the raw representation uses. This * @return The number of 32-bit words the raw representation uses. This
* will be 1 for an IPv4 address and 4 for an IPv6 address. * will be 1 for an IPv4 address and 4 for an IPv6 address.
*/ */
int GetBytes(const uint32_t** bytes) const int GetBytes(const uint32_t** bytes) const {
{ if ( GetFamily() == IPv4 ) {
if ( GetFamily() == IPv4 )
{
*bytes = (uint32_t*)&in6.s6_addr[12]; *bytes = (uint32_t*)&in6.s6_addr[12];
return 1; return 1;
} }
else else {
{
*bytes = (uint32_t*)in6.s6_addr; *bytes = (uint32_t*)in6.s6_addr;
return 4; return 4;
} }
@ -223,12 +207,10 @@ public:
* @param order The byte-order in which the returned raw bytes are copied. * @param order The byte-order in which the returned raw bytes are copied.
* The default is network order. * The default is network order.
*/ */
void CopyIPv6(uint32_t* bytes, ByteOrder order = Network) const void CopyIPv6(uint32_t* bytes, ByteOrder order = Network) const {
{
memcpy(bytes, in6.s6_addr, sizeof(in6.s6_addr)); memcpy(bytes, in6.s6_addr, sizeof(in6.s6_addr));
if ( order == Host ) if ( order == Host ) {
{
for ( unsigned int i = 0; i < 4; ++i ) for ( unsigned int i = 0; i < 4; ++i )
bytes[i] = ntohl(bytes[i]); bytes[i] = ntohl(bytes[i]);
} }
@ -238,10 +220,7 @@ public:
* Retrieves a copy of the IPv6 raw byte representation of the address. * Retrieves a copy of the IPv6 raw byte representation of the address.
* @see CopyIPv6(uint32_t) * @see CopyIPv6(uint32_t)
*/ */
void CopyIPv6(in6_addr* arg_in6) const void CopyIPv6(in6_addr* arg_in6) const { memcpy(arg_in6->s6_addr, in6.s6_addr, sizeof(in6.s6_addr)); }
{
memcpy(arg_in6->s6_addr, in6.s6_addr, sizeof(in6.s6_addr));
}
/** /**
* Retrieves a copy of the IPv4 raw byte representation of the address. * Retrieves a copy of the IPv4 raw byte representation of the address.
@ -251,10 +230,7 @@ public:
* @param in4 The pointer to a memory location in which the raw bytes * @param in4 The pointer to a memory location in which the raw bytes
* of the address are to be copied in network byte-order. * of the address are to be copied in network byte-order.
*/ */
void CopyIPv4(in4_addr* in4) const void CopyIPv4(in4_addr* in4) const { memcpy(&in4->s_addr, &in6.s6_addr[12], sizeof(in4->s_addr)); }
{
memcpy(&in4->s_addr, &in6.s6_addr[12], sizeof(in4->s_addr));
}
/** /**
* Returns a key that can be used to lookup the IP Address in a hash table. * Returns a key that can be used to lookup the IP Address in a hash table.
@ -287,8 +263,7 @@ public:
/** /**
* Assignment operator. * Assignment operator.
*/ */
IPAddr& operator=(const IPAddr& other) IPAddr& operator=(const IPAddr& other) {
{
// No self-assignment check here because it's correct without it and // No self-assignment check here because it's correct without it and
// makes the common case faster. // makes the common case faster.
in6 = other.in6; in6 = other.in6;
@ -299,8 +274,7 @@ public:
* Bitwise OR operator returns the IP address resulting from the bitwise * Bitwise OR operator returns the IP address resulting from the bitwise
* OR operation on the raw bytes of this address with another. * OR operation on the raw bytes of this address with another.
*/ */
IPAddr operator|(const IPAddr& other) IPAddr operator|(const IPAddr& other) {
{
in6_addr result; in6_addr result;
for ( int i = 0; i < 16; ++i ) for ( int i = 0; i < 16; ++i )
result.s6_addr[i] = this->in6.s6_addr[i] | other.in6.s6_addr[i]; result.s6_addr[i] = this->in6.s6_addr[i] | other.in6.s6_addr[i];
@ -320,8 +294,7 @@ public:
* in an URI. For IPv4 addresses, this is the same as AsString(), but * in an URI. For IPv4 addresses, this is the same as AsString(), but
* IPv6 addresses are encased in square brackets. * IPv6 addresses are encased in square brackets.
*/ */
std::string AsURIString() const std::string AsURIString() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return AsString(); return AsString();
@ -348,8 +321,7 @@ public:
/** /**
* Comparison operator for IP address. * Comparison operator for IP address.
*/ */
friend bool operator==(const IPAddr& addr1, const IPAddr& addr2) friend bool operator==(const IPAddr& addr1, const IPAddr& addr2) {
{
return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) == 0; return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) == 0;
} }
@ -360,15 +332,11 @@ public:
* IP addresses. However, the order does not necessarily correspond to * IP addresses. However, the order does not necessarily correspond to
* their numerical values. * their numerical values.
*/ */
friend bool operator<(const IPAddr& addr1, const IPAddr& addr2) friend bool operator<(const IPAddr& addr1, const IPAddr& addr2) {
{
return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) < 0; return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) < 0;
} }
friend bool operator<=(const IPAddr& addr1, const IPAddr& addr2) friend bool operator<=(const IPAddr& addr1, const IPAddr& addr2) { return addr1 < addr2 || addr1 == addr2; }
{
return addr1 < addr2 || addr1 == addr2;
}
friend bool operator>=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 < addr2); } friend bool operator>=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 < addr2); }
@ -412,8 +380,7 @@ public:
* *
* @return whether the string is a valid IP address * @return whether the string is a valid IP address
*/ */
static bool IsValid(const char* s) static bool IsValid(const char* s) {
{
in6_addr tmp; in6_addr tmp;
return ConvertString(s, &tmp); return ConvertString(s, &tmp);
} }
@ -446,28 +413,22 @@ private:
static constexpr uint8_t v4_mapped_prefix[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}; static constexpr uint8_t v4_mapped_prefix[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff};
}; };
inline IPAddr::IPAddr(Family family, const uint32_t* bytes, ByteOrder order) inline IPAddr::IPAddr(Family family, const uint32_t* bytes, ByteOrder order) {
{ if ( family == IPv4 ) {
if ( family == IPv4 )
{
memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix));
memcpy(&in6.s6_addr[12], bytes, sizeof(uint32_t)); memcpy(&in6.s6_addr[12], bytes, sizeof(uint32_t));
if ( order == Host ) if ( order == Host ) {
{
uint32_t* p = (uint32_t*)&in6.s6_addr[12]; uint32_t* p = (uint32_t*)&in6.s6_addr[12];
*p = htonl(*p); *p = htonl(*p);
} }
} }
else else {
{
memcpy(in6.s6_addr, bytes, sizeof(in6.s6_addr)); memcpy(in6.s6_addr, bytes, sizeof(in6.s6_addr));
if ( order == Host ) if ( order == Host ) {
{ for ( unsigned int i = 0; i < 4; ++i ) {
for ( unsigned int i = 0; i < 4; ++i )
{
uint32_t* p = (uint32_t*)&in6.s6_addr[i * 4]; uint32_t* p = (uint32_t*)&in6.s6_addr[i * 4];
*p = htonl(*p); *p = htonl(*p);
} }
@ -475,33 +436,24 @@ inline IPAddr::IPAddr(Family family, const uint32_t* bytes, ByteOrder order)
} }
} }
inline bool IPAddr::IsLoopback() const inline bool IPAddr::IsLoopback() const {
{
if ( GetFamily() == IPv4 ) if ( GetFamily() == IPv4 )
return in6.s6_addr[12] == 127; return in6.s6_addr[12] == 127;
else else
return ((in6.s6_addr[0] == 0) && (in6.s6_addr[1] == 0) && (in6.s6_addr[2] == 0) && return ((in6.s6_addr[0] == 0) && (in6.s6_addr[1] == 0) && (in6.s6_addr[2] == 0) && (in6.s6_addr[3] == 0) &&
(in6.s6_addr[3] == 0) && (in6.s6_addr[4] == 0) && (in6.s6_addr[5] == 0) && (in6.s6_addr[4] == 0) && (in6.s6_addr[5] == 0) && (in6.s6_addr[6] == 0) && (in6.s6_addr[7] == 0) &&
(in6.s6_addr[6] == 0) && (in6.s6_addr[7] == 0) && (in6.s6_addr[8] == 0) && (in6.s6_addr[8] == 0) && (in6.s6_addr[9] == 0) && (in6.s6_addr[10] == 0) && (in6.s6_addr[11] == 0) &&
(in6.s6_addr[9] == 0) && (in6.s6_addr[10] == 0) && (in6.s6_addr[11] == 0) && (in6.s6_addr[12] == 0) && (in6.s6_addr[13] == 0) && (in6.s6_addr[14] == 0) && (in6.s6_addr[15] == 1));
(in6.s6_addr[12] == 0) && (in6.s6_addr[13] == 0) && (in6.s6_addr[14] == 0) &&
(in6.s6_addr[15] == 1));
} }
inline void IPAddr::ConvertToThreadingValue(threading::Value::addr_t* v) const inline void IPAddr::ConvertToThreadingValue(threading::Value::addr_t* v) const {
{
v->family = GetFamily(); v->family = GetFamily();
switch ( v->family ) switch ( v->family ) {
{ case IPv4: CopyIPv4(&v->in.in4); return;
case IPv4:
CopyIPv4(&v->in.in4);
return;
case IPv6: case IPv6: CopyIPv6(&v->in.in6); return;
CopyIPv6(&v->in.in6);
return;
} }
} }
@ -509,8 +461,7 @@ inline void IPAddr::ConvertToThreadingValue(threading::Value::addr_t* v) const
* Class storing both IPv4 and IPv6 prefixes * Class storing both IPv4 and IPv6 prefixes
* (i.e., \c 192.168.1.1/16 and \c FD00::/8. * (i.e., \c 192.168.1.1/16 and \c FD00::/8.
*/ */
class IPPrefix class IPPrefix {
{
public: public:
/** /**
* Constructs a prefix 0/0. * Constructs a prefix 0/0.
@ -585,8 +536,7 @@ public:
* *
* @param addr The address to test. * @param addr The address to test.
*/ */
bool Contains(const IPAddr& addr) const bool Contains(const IPAddr& addr) const {
{
IPAddr p(addr); IPAddr p(addr);
p.Mask(length); p.Mask(length);
return p == prefix; return p == prefix;
@ -594,8 +544,7 @@ public:
/** /**
* Assignment operator. * Assignment operator.
*/ */
IPPrefix& operator=(const IPPrefix& other) IPPrefix& operator=(const IPPrefix& other) {
{
// No self-assignment check here because it's correct without it and // No self-assignment check here because it's correct without it and
// makes the common case faster. // makes the common case faster.
prefix = other.prefix; prefix = other.prefix;
@ -621,8 +570,7 @@ public:
* Converts the prefix into the type used internally by the * Converts the prefix into the type used internally by the
* inter-thread communication. * inter-thread communication.
*/ */
void ConvertToThreadingValue(threading::Value::subnet_t* v) const void ConvertToThreadingValue(threading::Value::subnet_t* v) const {
{
v->length = length; v->length = length;
prefix.ConvertToThreadingValue(&v->prefix); prefix.ConvertToThreadingValue(&v->prefix);
} }
@ -630,8 +578,7 @@ public:
/** /**
* Comparison operator for IP prefix. * Comparison operator for IP prefix.
*/ */
friend bool operator==(const IPPrefix& net1, const IPPrefix& net2) friend bool operator==(const IPPrefix& net1, const IPPrefix& net2) {
{
return net1.Prefix() == net2.Prefix() && net1.Length() == net2.Length(); return net1.Prefix() == net2.Prefix() && net1.Length() == net2.Length();
} }
@ -642,8 +589,7 @@ public:
* IP prefix. However, the order does not necessarily corresponding to their * IP prefix. However, the order does not necessarily corresponding to their
* numerical values. * numerical values.
*/ */
friend bool operator<(const IPPrefix& net1, const IPPrefix& net2) friend bool operator<(const IPPrefix& net1, const IPPrefix& net2) {
{
if ( net1.Prefix() < net2.Prefix() ) if ( net1.Prefix() < net2.Prefix() )
return true; return true;
@ -654,10 +600,7 @@ public:
return false; return false;
} }
friend bool operator<=(const IPPrefix& net1, const IPPrefix& net2) friend bool operator<=(const IPPrefix& net1, const IPPrefix& net2) { return net1 < net2 || net1 == net2; }
{
return net1 < net2 || net1 == net2;
}
friend bool operator>=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 < net2); } friend bool operator>=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 < net2); }
@ -679,8 +622,7 @@ public:
* *
* @return whether the string is a valid IP address prefix * @return whether the string is a valid IP address prefix
*/ */
static bool IsValid(const char* s) static bool IsValid(const char* s) {
{
IPPrefix tmp; IPPrefix tmp;
return ConvertString(s, &tmp); return ConvertString(s, &tmp);
} }

View file

@ -4,11 +4,9 @@
#include <cstring> #include <cstring>
namespace zeek::detail namespace zeek::detail {
{
void IntSet::Expand(unsigned int i) void IntSet::Expand(unsigned int i) {
{
unsigned int newsize = i / 8 + 1; unsigned int newsize = i / 8 + 1;
unsigned char* newset = new unsigned char[newsize]; unsigned char* newset = new unsigned char[newsize];

View file

@ -8,11 +8,9 @@
#include <cstring> #include <cstring>
namespace zeek::detail namespace zeek::detail {
{
class IntSet class IntSet {
{
public: public:
// n is a hint for the value of the largest integer. // n is a hint for the value of the largest integer.
explicit IntSet(unsigned int n = 1); explicit IntSet(unsigned int n = 1);
@ -31,42 +29,30 @@ private:
unsigned char* set; unsigned char* set;
}; };
inline IntSet::IntSet(unsigned int n) inline IntSet::IntSet(unsigned int n) {
{
size = n / 8 + 1; size = n / 8 + 1;
set = new unsigned char[size]; set = new unsigned char[size];
memset(set, 0, size); memset(set, 0, size);
} }
inline IntSet::~IntSet() inline IntSet::~IntSet() { delete[] set; }
{
delete[] set;
}
inline void IntSet::Insert(unsigned int i) inline void IntSet::Insert(unsigned int i) {
{
if ( i / 8 >= size ) if ( i / 8 >= size )
Expand(i); Expand(i);
set[i / 8] |= (1 << (i % 8)); set[i / 8] |= (1 << (i % 8));
} }
inline void IntSet::Remove(unsigned int i) inline void IntSet::Remove(unsigned int i) {
{
if ( i / 8 >= size ) if ( i / 8 >= size )
Expand(i); Expand(i);
else else
set[i / 8] &= ~(1 << (i % 8)); set[i / 8] &= ~(1 << (i % 8));
} }
inline bool IntSet::Contains(unsigned int i) const inline bool IntSet::Contains(unsigned int i) const { return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false; }
{
return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false;
}
inline void IntSet::Clear() inline void IntSet::Clear() { memset(set, 0, size); }
{
memset(set, 0, size);
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -8,24 +8,19 @@
#include "Obj.h" #include "Obj.h"
namespace zeek namespace zeek {
{
/** /**
* A tag class for the #IntrusivePtr constructor which means: adopt * A tag class for the #IntrusivePtr constructor which means: adopt
* the reference from the caller. * the reference from the caller.
*/ */
struct AdoptRef struct AdoptRef {};
{
};
/** /**
* A tag class for the #IntrusivePtr constructor which means: create a * A tag class for the #IntrusivePtr constructor which means: create a
* new reference to the object. * new reference to the object.
*/ */
struct NewRef struct NewRef {};
{
};
/** /**
* This has to be forward declared and known here in order for us to be able * This has to be forward declared and known here in order for us to be able
@ -55,8 +50,8 @@ class OpaqueVal;
* should use a smart pointer whenever possible to reduce boilerplate code and * should use a smart pointer whenever possible to reduce boilerplate code and
* increase robustness of the code (in particular w.r.t. exceptions). * increase robustness of the code (in particular w.r.t. exceptions).
*/ */
template <class T> class IntrusivePtr template<class T>
{ class IntrusivePtr {
public: public:
// -- member types // -- member types
@ -74,8 +69,7 @@ public:
constexpr IntrusivePtr() noexcept = default; constexpr IntrusivePtr() noexcept = default;
constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() {
{
// nop // nop
} }
@ -97,29 +91,24 @@ public:
* *
* @param raw_ptr Pointer to the shared object. * @param raw_ptr Pointer to the shared object.
*/ */
IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) {
{
if ( ptr_ ) if ( ptr_ )
Ref(ptr_); Ref(ptr_);
} }
IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) {
{
// nop // nop
} }
IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) {} IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) {}
template<class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>> template<class U, class = std::enable_if_t<std::is_convertible_v<U*, T*>>>
IntrusivePtr(IntrusivePtr<U> other) noexcept : ptr_(other.release()) IntrusivePtr(IntrusivePtr<U> other) noexcept : ptr_(other.release()) {
{
// nop // nop
} }
~IntrusivePtr() ~IntrusivePtr() {
{ if ( ptr_ ) {
if ( ptr_ )
{
// Specializing `OpaqueVal` as MSVC compiler does not detect it // Specializing `OpaqueVal` as MSVC compiler does not detect it
// inheriting from `zeek::Obj` so we have to do that manually. // inheriting from `zeek::Obj` so we have to do that manually.
if constexpr ( std::is_same_v<T, OpaqueVal> ) if constexpr ( std::is_same_v<T, OpaqueVal> )
@ -131,8 +120,7 @@ public:
void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); } void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); }
friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept {
{
using std::swap; using std::swap;
swap(a.ptr_, b.ptr_); swap(a.ptr_, b.ptr_);
} }
@ -144,23 +132,19 @@ public:
*/ */
pointer release() noexcept { return std::exchange(ptr_, nullptr); } pointer release() noexcept { return std::exchange(ptr_, nullptr); }
IntrusivePtr& operator=(const IntrusivePtr& other) noexcept IntrusivePtr& operator=(const IntrusivePtr& other) noexcept {
{
IntrusivePtr tmp{other}; IntrusivePtr tmp{other};
swap(tmp); swap(tmp);
return *this; return *this;
} }
IntrusivePtr& operator=(IntrusivePtr&& other) noexcept IntrusivePtr& operator=(IntrusivePtr&& other) noexcept {
{
swap(other); swap(other);
return *this; return *this;
} }
IntrusivePtr& operator=(std::nullptr_t) noexcept IntrusivePtr& operator=(std::nullptr_t) noexcept {
{ if ( ptr_ ) {
if ( ptr_ )
{
Unref(ptr_); Unref(ptr_);
ptr_ = nullptr; ptr_ = nullptr;
} }
@ -189,8 +173,8 @@ private:
* @note This function assumes that any @c T starts with a reference count of 1. * @note This function assumes that any @c T starts with a reference count of 1.
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T, class... Ts> IntrusivePtr<T> make_intrusive(Ts&&... args) template<class T, class... Ts>
{ IntrusivePtr<T> make_intrusive(Ts&&... args) {
// Assumes that objects start with a reference count of 1! // Assumes that objects start with a reference count of 1!
return {AdoptRef{}, new T(std::forward<Ts>(args)...)}; return {AdoptRef{}, new T(std::forward<Ts>(args)...)};
} }
@ -201,8 +185,8 @@ template <class T, class... Ts> IntrusivePtr<T> make_intrusive(Ts&&... args)
* @param p The pointer of type @c U to cast to another type, @c T. * @param p The pointer of type @c U to cast to another type, @c T.
* @return The pointer, as cast to type @c T. * @return The pointer, as cast to type @c T.
*/ */
template <class T, class U> IntrusivePtr<T> cast_intrusive(IntrusivePtr<U> p) noexcept template<class T, class U>
{ IntrusivePtr<T> cast_intrusive(IntrusivePtr<U> p) noexcept {
return {AdoptRef{}, static_cast<T*>(p.release())}; return {AdoptRef{}, static_cast<T*>(p.release())};
} }
@ -211,32 +195,32 @@ template <class T, class U> IntrusivePtr<T> cast_intrusive(IntrusivePtr<U> p) no
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const zeek::IntrusivePtr<T>& x, std::nullptr_t) template<class T>
{ bool operator==(const zeek::IntrusivePtr<T>& x, std::nullptr_t) {
return ! x; return ! x;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(std::nullptr_t, const zeek::IntrusivePtr<T>& x) template<class T>
{ bool operator==(std::nullptr_t, const zeek::IntrusivePtr<T>& x) {
return ! x; return ! x;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const zeek::IntrusivePtr<T>& x, std::nullptr_t) template<class T>
{ bool operator!=(const zeek::IntrusivePtr<T>& x, std::nullptr_t) {
return static_cast<bool>(x); return static_cast<bool>(x);
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(std::nullptr_t, const zeek::IntrusivePtr<T>& x) template<class T>
{ bool operator!=(std::nullptr_t, const zeek::IntrusivePtr<T>& x) {
return static_cast<bool>(x); return static_cast<bool>(x);
} }
@ -245,32 +229,32 @@ template <class T> bool operator!=(std::nullptr_t, const zeek::IntrusivePtr<T>&
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const zeek::IntrusivePtr<T>& x, const T* y) template<class T>
{ bool operator==(const zeek::IntrusivePtr<T>& x, const T* y) {
return x.get() == y; return x.get() == y;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator==(const T* x, const zeek::IntrusivePtr<T>& y) template<class T>
{ bool operator==(const T* x, const zeek::IntrusivePtr<T>& y) {
return x == y.get(); return x == y.get();
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const zeek::IntrusivePtr<T>& x, const T* y) template<class T>
{ bool operator!=(const zeek::IntrusivePtr<T>& x, const T* y) {
return x.get() != y; return x.get() != y;
} }
/** /**
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template <class T> bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y) template<class T>
{ bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y) {
return x != y.get(); return x != y.get();
} }
@ -283,9 +267,7 @@ template <class T> bool operator!=(const T* x, const zeek::IntrusivePtr<T>& y)
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template<class T, class U> template<class T, class U>
auto operator==(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) auto operator==(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) -> decltype(x.get() == y.get()) {
-> decltype(x.get() == y.get())
{
return x.get() == y.get(); return x.get() == y.get();
} }
@ -293,9 +275,7 @@ auto operator==(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y)
* @relates IntrusivePtr * @relates IntrusivePtr
*/ */
template<class T, class U> template<class T, class U>
auto operator!=(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) auto operator!=(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y) -> decltype(x.get() != y.get()) {
-> decltype(x.get() != y.get())
{
return x.get() != y.get(); return x.get() != y.get();
} }
@ -303,14 +283,10 @@ auto operator!=(const zeek::IntrusivePtr<T>& x, const zeek::IntrusivePtr<U>& y)
// -- hashing ------------------------------------------------ // -- hashing ------------------------------------------------
namespace std namespace std {
{ template<class T>
template <class T> struct hash<zeek::IntrusivePtr<T>> struct hash<zeek::IntrusivePtr<T>> {
{
// Hash of intrusive pointer is the same as hash of the raw pointer it holds. // Hash of intrusive pointer is the same as hash of the raw pointer it holds.
size_t operator()(const zeek::IntrusivePtr<T>& v) const noexcept size_t operator()(const zeek::IntrusivePtr<T>& v) const noexcept { return std::hash<T*>{}(v.get()); }
{
return std::hash<T*>{}(v.get());
}
}; };
} } // namespace std

View file

@ -2,8 +2,7 @@
#include "zeek/3rdparty/doctest.h" #include "zeek/3rdparty/doctest.h"
TEST_CASE("list construction") TEST_CASE("list construction") {
{
zeek::List<int> list; zeek::List<int> list;
CHECK(list.empty()); CHECK(list.empty());
@ -12,8 +11,7 @@ TEST_CASE("list construction")
CHECK(list2.max() == 10); CHECK(list2.max() == 10);
} }
TEST_CASE("list operation") TEST_CASE("list operation") {
{
zeek::List<int> list({1, 2, 3}); zeek::List<int> list({1, 2, 3});
CHECK(list.size() == 3); CHECK(list.size() == 3);
CHECK(list.max() == 3); CHECK(list.max() == 3);
@ -84,8 +82,7 @@ TEST_CASE("list operation")
CHECK(list.max() == 0); CHECK(list.max() == 0);
} }
TEST_CASE("list iteration") TEST_CASE("list iteration") {
{
zeek::List<int> list({1, 2, 3, 4}); zeek::List<int> list({1, 2, 3, 4});
int index = 1; int index = 1;
@ -97,8 +94,7 @@ TEST_CASE("list iteration")
CHECK(*it == index); CHECK(*it == index);
} }
TEST_CASE("plists") TEST_CASE("plists") {
{
zeek::PList<int> list; zeek::PList<int> list;
list.push_back(new int{1}); list.push_back(new int{1});
list.push_back(new int{2}); list.push_back(new int{2});
@ -116,8 +112,7 @@ TEST_CASE("plists")
list.clear(); list.clear();
} }
TEST_CASE("unordered list operation") TEST_CASE("unordered list operation") {
{
zeek::List<int, zeek::ListOrder::UNORDERED> list({1, 2, 3, 4}); zeek::List<int, zeek::ListOrder::UNORDERED> list({1, 2, 3, 4});
CHECK(list.size() == 4); CHECK(list.size() == 4);

View file

@ -27,28 +27,21 @@
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek namespace zeek {
{
enum class ListOrder : int enum class ListOrder : int { ORDERED, UNORDERED };
{
ORDERED,
UNORDERED
};
template <typename T, ListOrder Order = ListOrder::ORDERED> class List template<typename T, ListOrder Order = ListOrder::ORDERED>
{ class List {
public: public:
constexpr static int DEFAULT_LIST_SIZE = 10; constexpr static int DEFAULT_LIST_SIZE = 10;
constexpr static int LIST_GROWTH_FACTOR = 2; constexpr static int LIST_GROWTH_FACTOR = 2;
~List() { free(entries); } ~List() { free(entries); }
explicit List(int size = 0) explicit List(int size = 0) {
{
num_entries = 0; num_entries = 0;
if ( size <= 0 ) if ( size <= 0 ) {
{
max_entries = 0; max_entries = 0;
entries = nullptr; entries = nullptr;
return; return;
@ -59,8 +52,7 @@ public:
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); entries = (T*)util::safe_malloc(max_entries * sizeof(T));
} }
List(const List& b) List(const List& b) {
{
max_entries = b.max_entries; max_entries = b.max_entries;
num_entries = b.num_entries; num_entries = b.num_entries;
@ -73,8 +65,7 @@ public:
entries[i] = b.entries[i]; entries[i] = b.entries[i];
} }
List(List&& b) List(List&& b) {
{
entries = b.entries; entries = b.entries;
num_entries = b.num_entries; num_entries = b.num_entries;
max_entries = b.max_entries; max_entries = b.max_entries;
@ -83,8 +74,7 @@ public:
b.num_entries = b.max_entries = 0; b.num_entries = b.max_entries = 0;
} }
List(const T* arr, int n) List(const T* arr, int n) {
{
num_entries = max_entries = n; num_entries = max_entries = n;
entries = (T*)util::safe_malloc(max_entries * sizeof(T)); entries = (T*)util::safe_malloc(max_entries * sizeof(T));
memcpy(entries, arr, n * sizeof(T)); memcpy(entries, arr, n * sizeof(T));
@ -92,8 +82,7 @@ public:
List(std::initializer_list<T> il) : List(il.begin(), il.size()) {} List(std::initializer_list<T> il) : List(il.begin(), il.size()) {}
List& operator=(const List& b) List& operator=(const List& b) {
{
if ( this == &b ) if ( this == &b )
return *this; return *this;
@ -113,8 +102,7 @@ public:
return *this; return *this;
} }
List& operator=(List&& b) List& operator=(List&& b) {
{
if ( this == &b ) if ( this == &b )
return *this; return *this;
@ -148,8 +136,7 @@ public:
if ( new_size < num_entries ) if ( new_size < num_entries )
new_size = num_entries; // do not lose any entries new_size = num_entries; // do not lose any entries
if ( new_size != max_entries ) if ( new_size != max_entries ) {
{
entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size); entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size);
if ( entries ) if ( entries )
max_entries = new_size; max_entries = new_size;
@ -160,8 +147,7 @@ public:
return max_entries; return max_entries;
} }
void push_front(const T& a) void push_front(const T& a) {
{
if ( num_entries == max_entries ) if ( num_entries == max_entries )
resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
@ -172,8 +158,7 @@ public:
entries[0] = a; entries[0] = a;
} }
void push_back(const T& a) void push_back(const T& a) {
{
if ( num_entries == max_entries ) if ( num_entries == max_entries )
resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE);
@ -196,8 +181,7 @@ public:
bool remove(const T& a) // delete entry from list bool remove(const T& a) // delete entry from list
{ {
int pos = member_pos(a); int pos = member_pos(a);
if ( pos != -1 ) if ( pos != -1 ) {
{
remove_nth(pos); remove_nth(pos);
return true; return true;
} }
@ -214,15 +198,13 @@ public:
// For data where we don't care about ordering, we don't care about keeping // For data where we don't care about ordering, we don't care about keeping
// the list in the same order when removing an element. Just swap the last // the list in the same order when removing an element. Just swap the last
// element with the element being removed. // element with the element being removed.
if constexpr ( Order == ListOrder::ORDERED ) if constexpr ( Order == ListOrder::ORDERED ) {
{
--num_entries; --num_entries;
for ( ; n < num_entries; ++n ) for ( ; n < num_entries; ++n )
entries[n] = entries[n + 1]; entries[n] = entries[n + 1];
} }
else else {
{
entries[n] = entries[num_entries - 1]; entries[n] = entries[num_entries - 1];
--num_entries; --num_entries;
} }
@ -231,15 +213,13 @@ public:
} }
// Return 0 if ent is not in the list, ent otherwise. // Return 0 if ent is not in the list, ent otherwise.
bool is_member(const T& a) const bool is_member(const T& a) const {
{
int pos = member_pos(a); int pos = member_pos(a);
return pos != -1; return pos != -1;
} }
// Returns -1 if ent is not in the list, otherwise its position. // Returns -1 if ent is not in the list, otherwise its position.
int member_pos(const T& e) const int member_pos(const T& e) const {
{
int i; int i;
for ( i = 0; i < length() && e != entries[i]; ++i ) for ( i = 0; i < length() && e != entries[i]; ++i )
; ;
@ -254,8 +234,7 @@ public:
T old_ent{}; T old_ent{};
if ( ent_index > num_entries - 1 ) if ( ent_index > num_entries - 1 ) { // replacement beyond the end of the list
{ // replacement beyond the end of the list
resize(ent_index + 1); resize(ent_index + 1);
for ( int i = num_entries; i < max_entries; ++i ) for ( int i = num_entries; i < max_entries; ++i )
@ -326,7 +305,8 @@ protected:
}; };
// Specialization of the List class to store pointers of a type. // Specialization of the List class to store pointers of a type.
template <typename T, ListOrder Order = ListOrder::ORDERED> using PList = List<T*, Order>; template<typename T, ListOrder Order = ListOrder::ORDERED>
using PList = List<T*, Order>;
// Popular type of list: list of strings. // Popular type of list: list of strings.
using name_list = PList<char>; using name_list = PList<char>;

View file

@ -10,13 +10,11 @@
#include "zeek/EquivClass.h" #include "zeek/EquivClass.h"
#include "zeek/IntSet.h" #include "zeek/IntSet.h"
namespace zeek::detail namespace zeek::detail {
{
static int nfa_state_id = 0; static int nfa_state_id = 0;
NFA_State::NFA_State(int arg_sym, EquivClass* ec) NFA_State::NFA_State(int arg_sym, EquivClass* ec) {
{
sym = arg_sym; sym = arg_sym;
ccl = nullptr; ccl = nullptr;
accept = NO_ACCEPT; accept = NO_ACCEPT;
@ -37,8 +35,7 @@ NFA_State::NFA_State(int arg_sym, EquivClass* ec)
ec->UniqueChar(sym); ec->UniqueChar(sym);
} }
NFA_State::NFA_State(CCL* arg_ccl) NFA_State::NFA_State(CCL* arg_ccl) {
{
sym = SYM_CCL; sym = SYM_CCL;
ccl = arg_ccl; ccl = arg_ccl;
accept = NO_ACCEPT; accept = NO_ACCEPT;
@ -48,8 +45,7 @@ NFA_State::NFA_State(CCL* arg_ccl)
epsclosure = nullptr; epsclosure = nullptr;
} }
NFA_State::~NFA_State() NFA_State::~NFA_State() {
{
for ( int i = 0; i < xtions.length(); ++i ) for ( int i = 0; i < xtions.length(); ++i )
if ( i > 0 || ! first_trans_is_back_ref ) if ( i > 0 || ! first_trans_is_back_ref )
Unref(xtions[i]); Unref(xtions[i]);
@ -57,16 +53,13 @@ NFA_State::~NFA_State()
delete epsclosure; delete epsclosure;
} }
void NFA_State::AddXtionsTo(NFA_state_list* ns) void NFA_State::AddXtionsTo(NFA_state_list* ns) {
{
for ( int i = 0; i < xtions.length(); ++i ) for ( int i = 0; i < xtions.length(); ++i )
ns->push_back(xtions[i]); ns->push_back(xtions[i]);
} }
NFA_State* NFA_State::DeepCopy() NFA_State* NFA_State::DeepCopy() {
{ if ( mark ) {
if ( mark )
{
Ref(mark); Ref(mark);
return mark; return mark;
} }
@ -80,18 +73,15 @@ NFA_State* NFA_State::DeepCopy()
return copy; return copy;
} }
void NFA_State::ClearMarks() void NFA_State::ClearMarks() {
{ if ( mark ) {
if ( mark )
{
SetMark(nullptr); SetMark(nullptr);
for ( int i = 0; i < xtions.length(); ++i ) for ( int i = 0; i < xtions.length(); ++i )
xtions[i]->ClearMarks(); xtions[i]->ClearMarks();
} }
} }
NFA_state_list* NFA_State::EpsilonClosure() NFA_state_list* NFA_State::EpsilonClosure() {
{
if ( epsclosure ) if ( epsclosure )
return epsclosure; return epsclosure;
@ -102,17 +92,13 @@ NFA_state_list* NFA_State::EpsilonClosure()
SetMark(this); SetMark(this);
int i; int i;
for ( i = 0; i < states.length(); ++i ) for ( i = 0; i < states.length(); ++i ) {
{
NFA_State* ns = states[i]; NFA_State* ns = states[i];
if ( ns->TransSym() == SYM_EPSILON ) if ( ns->TransSym() == SYM_EPSILON ) {
{
NFA_state_list* x = ns->Transitions(); NFA_state_list* x = ns->Transitions();
for ( int j = 0; j < x->length(); ++j ) for ( int j = 0; j < x->length(); ++j ) {
{
NFA_State* nxt = (*x)[j]; NFA_State* nxt = (*x)[j];
if ( ! nxt->Mark() ) if ( ! nxt->Mark() ) {
{
states.push_back(nxt); states.push_back(nxt);
nxt->SetMark(nxt); nxt->SetMark(nxt);
} }
@ -137,13 +123,9 @@ NFA_state_list* NFA_State::EpsilonClosure()
return epsclosure; return epsclosure;
} }
void NFA_State::Describe(ODesc* d) const void NFA_State::Describe(ODesc* d) const { d->Add("NFA state"); }
{
d->Add("NFA state");
}
void NFA_State::Dump(FILE* f) void NFA_State::Dump(FILE* f) {
{
if ( mark ) if ( mark )
return; return;
@ -157,32 +139,23 @@ void NFA_State::Dump(FILE* f)
xtions[i]->Dump(f); xtions[i]->Dump(f);
} }
NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) {
{
first_state = first; first_state = first;
final_state = final ? final : first; final_state = final ? final : first;
eol = bol = 0; eol = bol = 0;
} }
NFA_Machine::~NFA_Machine() NFA_Machine::~NFA_Machine() { Unref(first_state); }
{
Unref(first_state);
}
void NFA_Machine::InsertEpsilon() void NFA_Machine::InsertEpsilon() {
{
NFA_State* eps = new EpsilonState(); NFA_State* eps = new EpsilonState();
eps->AddXtion(first_state); eps->AddXtion(first_state);
first_state = eps; first_state = eps;
} }
void NFA_Machine::AppendEpsilon() void NFA_Machine::AppendEpsilon() { AppendState(new EpsilonState()); }
{
AppendState(new EpsilonState());
}
void NFA_Machine::AddAccept(int accept_val) void NFA_Machine::AddAccept(int accept_val) {
{
// Hang the accepting number off an epsilon state. If it is associated // Hang the accepting number off an epsilon state. If it is associated
// with a state that has a non-epsilon out-transition, then the state // with a state that has a non-epsilon out-transition, then the state
// will accept BEFORE it makes that transition, i.e., one character // will accept BEFORE it makes that transition, i.e., one character
@ -194,8 +167,7 @@ void NFA_Machine::AddAccept(int accept_val)
final_state->SetAccept(accept_val); final_state->SetAccept(accept_val);
} }
void NFA_Machine::LinkCopies(int n) void NFA_Machine::LinkCopies(int n) {
{
if ( n <= 0 ) if ( n <= 0 )
return; return;
@ -214,8 +186,7 @@ void NFA_Machine::LinkCopies(int n)
delete[] copies; delete[] copies;
} }
NFA_Machine* NFA_Machine::DuplicateMachine() NFA_Machine* NFA_Machine::DuplicateMachine() {
{
NFA_State* new_first_state = first_state->DeepCopy(); NFA_State* new_first_state = first_state->DeepCopy();
NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark()); NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark());
first_state->ClearMarks(); first_state->ClearMarks();
@ -223,14 +194,12 @@ NFA_Machine* NFA_Machine::DuplicateMachine()
return new_m; return new_m;
} }
void NFA_Machine::AppendState(NFA_State* s) void NFA_Machine::AppendState(NFA_State* s) {
{
final_state->AddXtion(s); final_state->AddXtion(s);
final_state = s; final_state = s;
} }
void NFA_Machine::AppendMachine(NFA_Machine* m) void NFA_Machine::AppendMachine(NFA_Machine* m) {
{
AppendEpsilon(); AppendEpsilon();
final_state->AddXtion(m->FirstState()); final_state->AddXtion(m->FirstState());
final_state = m->FinalState(); final_state = m->FinalState();
@ -239,16 +208,14 @@ void NFA_Machine::AppendMachine(NFA_Machine* m)
Unref(m); Unref(m);
} }
void NFA_Machine::MakeOptional() void NFA_Machine::MakeOptional() {
{
InsertEpsilon(); InsertEpsilon();
AppendEpsilon(); AppendEpsilon();
first_state->AddXtion(final_state); first_state->AddXtion(final_state);
Ref(final_state); Ref(final_state);
} }
void NFA_Machine::MakePositiveClosure() void NFA_Machine::MakePositiveClosure() {
{
AppendEpsilon(); AppendEpsilon();
final_state->AddXtion(first_state); final_state->AddXtion(first_state);
@ -257,23 +224,20 @@ void NFA_Machine::MakePositiveClosure()
final_state->SetFirstTransIsBackRef(); final_state->SetFirstTransIsBackRef();
} }
void NFA_Machine::MakeRepl(int lower, int upper) void NFA_Machine::MakeRepl(int lower, int upper) {
{
NFA_Machine* dup = nullptr; NFA_Machine* dup = nullptr;
if ( upper > lower || upper == NO_UPPER_BOUND ) if ( upper > lower || upper == NO_UPPER_BOUND )
dup = DuplicateMachine(); dup = DuplicateMachine();
LinkCopies(lower - 1); LinkCopies(lower - 1);
if ( upper == NO_UPPER_BOUND ) if ( upper == NO_UPPER_BOUND ) {
{
dup->MakeClosure(); dup->MakeClosure();
AppendMachine(dup); AppendMachine(dup);
return; return;
} }
while ( upper > lower ) while ( upper > lower ) {
{
NFA_Machine* dup2; NFA_Machine* dup2;
if ( --upper == lower ) if ( --upper == lower )
// Don't need "dup" for any further copies // Don't need "dup" for any further copies
@ -286,19 +250,14 @@ void NFA_Machine::MakeRepl(int lower, int upper)
} }
} }
void NFA_Machine::Describe(ODesc* d) const void NFA_Machine::Describe(ODesc* d) const { d->Add("NFA machine"); }
{
d->Add("NFA machine");
}
void NFA_Machine::Dump(FILE* f) void NFA_Machine::Dump(FILE* f) {
{
first_state->Dump(f); first_state->Dump(f);
first_state->ClearMarks(); first_state->ClearMarks();
} }
NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) {
{
if ( ! m1 ) if ( ! m1 )
return m2; return m2;
if ( ! m2 ) if ( ! m2 )
@ -324,23 +283,19 @@ NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2)
return new NFA_Machine(first, last); return new NFA_Machine(first, last);
} }
NFA_state_list* epsilon_closure(NFA_state_list* states) NFA_state_list* epsilon_closure(NFA_state_list* states) {
{
// We just keep one of this as it may get quite large. // We just keep one of this as it may get quite large.
static IntSet closuremap; static IntSet closuremap;
closuremap.Clear(); closuremap.Clear();
NFA_state_list* closure = new NFA_state_list; NFA_state_list* closure = new NFA_state_list;
for ( int i = 0; i < states->length(); ++i ) for ( int i = 0; i < states->length(); ++i ) {
{
NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure(); NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure();
for ( int j = 0; j < stateclosure->length(); ++j ) for ( int j = 0; j < stateclosure->length(); ++j ) {
{
NFA_State* ns = (*stateclosure)[j]; NFA_State* ns = (*stateclosure)[j];
if ( ! closuremap.Contains(ns->ID()) ) if ( ! closuremap.Contains(ns->ID()) ) {
{
closuremap.Insert(ns->ID()); closuremap.Insert(ns->ID());
closure->push_back(ns); closure->push_back(ns);
} }
@ -358,8 +313,7 @@ NFA_state_list* epsilon_closure(NFA_state_list* states)
return closure; return closure;
} }
bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2) bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2) {
{
if ( v1->ID() < v2->ID() ) if ( v1->ID() < v2->ID() )
return true; return true;
else else

View file

@ -16,13 +16,11 @@
#define SYM_EPSILON 259 #define SYM_EPSILON 259
#define SYM_CCL 260 #define SYM_CCL 260
namespace zeek namespace zeek {
{
class Func; class Func;
namespace detail namespace detail {
{
class CCL; class CCL;
class EquivClass; class EquivClass;
@ -30,8 +28,7 @@ class EquivClass;
class NFA_State; class NFA_State;
using NFA_state_list = PList<NFA_State>; using NFA_state_list = PList<NFA_State>;
class NFA_State : public Obj class NFA_State : public Obj {
{
public: public:
NFA_State(int sym, EquivClass* ec); NFA_State(int sym, EquivClass* ec);
explicit NFA_State(CCL* ccl); explicit NFA_State(CCL* ccl);
@ -79,14 +76,12 @@ protected:
NFA_State* mark; NFA_State* mark;
}; };
class EpsilonState : public NFA_State class EpsilonState : public NFA_State {
{
public: public:
EpsilonState() : NFA_State(SYM_EPSILON, nullptr) {} EpsilonState() : NFA_State(SYM_EPSILON, nullptr) {}
}; };
class NFA_Machine : public Obj class NFA_Machine : public Obj {
{
public: public:
explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr); explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr);
~NFA_Machine() override; ~NFA_Machine() override;
@ -98,8 +93,7 @@ public:
void AddAccept(int accept_val); void AddAccept(int accept_val);
void MakeClosure() void MakeClosure() {
{
MakePositiveClosure(); MakePositiveClosure();
MakeOptional(); MakeOptional();
} }

View file

@ -105,8 +105,7 @@ zeek::StringVal* cmd_line_bpf_filter;
zeek::StringVal* global_hash_seed; zeek::StringVal* global_hash_seed;
namespace zeek::detail namespace zeek::detail {
{
int watchdog_interval; int watchdog_interval;
@ -194,27 +193,24 @@ zeek_uint_t bits_per_uid;
zeek_uint_t tunnel_max_changes_per_connection; zeek_uint_t tunnel_max_changes_per_connection;
} // namespace zeek::detail. The namespace has be closed here before we include the netvar_def } // namespace zeek::detail
// files. // files.
// Because of how the BIF include files are built with namespaces already in them, // Because of how the BIF include files are built with namespaces already in them,
// these files need to be included separately before the namespace is opened below. // these files need to be included separately before the namespace is opened below.
static void bif_init_event_handlers() static void bif_init_event_handlers() {
{
#include "event.bif.netvar_init" #include "event.bif.netvar_init"
} }
static void bif_init_net_var() static void bif_init_net_var() {
{
#include "const.bif.netvar_init" #include "const.bif.netvar_init"
#include "packet_analysis.bif.netvar_init" #include "packet_analysis.bif.netvar_init"
#include "reporter.bif.netvar_init" #include "reporter.bif.netvar_init"
#include "supervisor.bif.netvar_init" #include "supervisor.bif.netvar_init"
} }
static void init_bif_types() static void init_bif_types() {
{
#include "types.bif.netvar_init" #include "types.bif.netvar_init"
} }
@ -226,16 +222,11 @@ static void init_bif_types()
#include "types.bif.netvar_def" #include "types.bif.netvar_def"
// Re-open the namespace now that the bif headers are all included. // Re-open the namespace now that the bif headers are all included.
namespace zeek::detail namespace zeek::detail {
{
void init_event_handlers() void init_event_handlers() { bif_init_event_handlers(); }
{
bif_init_event_handlers();
}
void init_general_global_var() void init_general_global_var() {
{
table_expire_interval = id::find_val("table_expire_interval")->AsInterval(); table_expire_interval = id::find_val("table_expire_interval")->AsInterval();
table_expire_delay = id::find_val("table_expire_delay")->AsInterval(); table_expire_delay = id::find_val("table_expire_delay")->AsInterval();
table_incremental_step = id::find_val("table_incremental_step")->AsCount(); table_incremental_step = id::find_val("table_incremental_step")->AsCount();
@ -246,14 +237,12 @@ void init_general_global_var()
bits_per_uid = id::find_val("bits_per_uid")->AsCount(); bits_per_uid = id::find_val("bits_per_uid")->AsCount();
} }
void init_builtin_types() void init_builtin_types() {
{
init_bif_types(); init_bif_types();
id::detail::init_types(); id::detail::init_types();
} }
void init_net_var() void init_net_var() {
{
bif_init_net_var(); bif_init_net_var();
ignore_checksums = id::find_val("ignore_checksums")->AsBool(); ignore_checksums = id::find_val("ignore_checksums")->AsBool();
@ -272,10 +261,8 @@ void init_net_var()
tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval(); tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval();
tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount(); tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount();
tcp_max_above_hole_without_any_acks = tcp_max_above_hole_without_any_acks = id::find_val("tcp_max_above_hole_without_any_acks")->AsCount();
id::find_val("tcp_max_above_hole_without_any_acks")->AsCount(); tcp_excessive_data_without_further_acks = id::find_val("tcp_excessive_data_without_further_acks")->AsCount();
tcp_excessive_data_without_further_acks =
id::find_val("tcp_excessive_data_without_further_acks")->AsCount();
tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount(); tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount();
non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval(); non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval();
@ -291,8 +278,7 @@ void init_net_var()
udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool()); udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool());
udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool()); udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool());
udp_content_delivery_ports_use_resp = bool( udp_content_delivery_ports_use_resp = bool(id::find_val("udp_content_delivery_ports_use_resp")->AsBool());
id::find_val("udp_content_delivery_ports_use_resp")->AsBool());
dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval(); dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval();
rpc_timeout = id::find_val("rpc_timeout")->AsInterval(); rpc_timeout = id::find_val("rpc_timeout")->AsInterval();
@ -345,8 +331,7 @@ void init_net_var()
dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool(); dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool();
dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool(); dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool();
tunnel_max_changes_per_connection = tunnel_max_changes_per_connection = id::find_val("Tunnel::max_changes_per_connection")->AsCount();
id::find_val("Tunnel::max_changes_per_connection")->AsCount();
} }
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -6,8 +6,7 @@
#include "zeek/Stats.h" #include "zeek/Stats.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
extern int watchdog_interval; extern int watchdog_interval;

View file

@ -8,42 +8,30 @@
zeek::notifier::detail::Registry zeek::notifier::detail::registry; zeek::notifier::detail::Registry zeek::notifier::detail::registry;
namespace zeek::notifier::detail namespace zeek::notifier::detail {
{
Receiver::Receiver() Receiver::Receiver() { DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this); }
{
DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this);
}
Receiver::~Receiver() Receiver::~Receiver() { DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this); }
{
DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this);
}
Registry::~Registry() Registry::~Registry() {
{
while ( registrations.begin() != registrations.end() ) while ( registrations.begin() != registrations.end() )
Unregister(registrations.begin()->first); Unregister(registrations.begin()->first);
} }
void Registry::Register(Modifiable* m, Receiver* r) void Registry::Register(Modifiable* m, Receiver* r) {
{
DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r); DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r);
registrations.insert({m, r}); registrations.insert({m, r});
++m->num_receivers; ++m->num_receivers;
} }
void Registry::Unregister(Modifiable* m, Receiver* r) void Registry::Unregister(Modifiable* m, Receiver* r) {
{
DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r); DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
for ( auto i = x.first; i != x.second; i++ ) for ( auto i = x.first; i != x.second; i++ ) {
{ if ( i->second == r ) {
if ( i->second == r )
{
--i->first->num_receivers; --i->first->num_receivers;
registrations.erase(i); registrations.erase(i);
break; break;
@ -51,8 +39,7 @@ void Registry::Unregister(Modifiable* m, Receiver* r)
} }
} }
void Registry::Unregister(Modifiable* m) void Registry::Unregister(Modifiable* m) {
{
DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m); DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
@ -62,8 +49,7 @@ void Registry::Unregister(Modifiable* m)
registrations.erase(x.first, x.second); registrations.erase(x.first, x.second);
} }
void Registry::Modified(Modifiable* m) void Registry::Modified(Modifiable* m) {
{
DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m); DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m);
auto x = registrations.equal_range(m); auto x = registrations.equal_range(m);
@ -71,8 +57,7 @@ void Registry::Modified(Modifiable* m)
i->second->Modified(m); i->second->Modified(m);
} }
void Registry::Terminate() void Registry::Terminate() {
{
std::set<Receiver*> receivers; std::set<Receiver*> receivers;
for ( auto& r : registrations ) for ( auto& r : registrations )
@ -82,8 +67,7 @@ void Registry::Terminate()
r->Terminate(); r->Terminate();
} }
Modifiable::~Modifiable() Modifiable::~Modifiable() {
{
if ( num_receivers ) if ( num_receivers )
registry.Unregister(this); registry.Unregister(this);
} }

View file

@ -10,14 +10,12 @@
#include <cstdint> #include <cstdint>
#include <unordered_map> #include <unordered_map>
namespace zeek::notifier::detail namespace zeek::notifier::detail {
{
class Modifiable; class Modifiable;
/** Interface class for receivers of notifications. */ /** Interface class for receivers of notifications. */
class Receiver class Receiver {
{
public: public:
Receiver(); Receiver();
virtual ~Receiver(); virtual ~Receiver();
@ -37,8 +35,7 @@ public:
}; };
/** Singleton class tracking all notification requests globally. */ /** Singleton class tracking all notification requests globally. */
class Registry class Registry {
{
public: public:
~Registry(); ~Registry();
@ -100,15 +97,13 @@ extern Registry registry;
* Base class for objects that can trigger notifications to receivers when * Base class for objects that can trigger notifications to receivers when
* modified. * modified.
*/ */
class Modifiable class Modifiable {
{
public: public:
/** /**
* Calling this method signals to all registered receivers that the * Calling this method signals to all registered receivers that the
* object has been modified. * object has been modified.
*/ */
void Modified() void Modified() {
{
if ( num_receivers ) if ( num_receivers )
registry.Modified(this); registry.Modified(this);
} }

View file

@ -11,18 +11,14 @@
#include "zeek/Func.h" #include "zeek/Func.h"
#include "zeek/plugin/Manager.h" #include "zeek/plugin/Manager.h"
namespace zeek namespace zeek {
{ namespace detail {
namespace detail
{
Location start_location("<start uninitialized>", 0, 0, 0, 0); Location start_location("<start uninitialized>", 0, 0, 0, 0);
Location end_location("<end uninitialized>", 0, 0, 0, 0); Location end_location("<end uninitialized>", 0, 0, 0, 0);
void Location::Describe(ODesc* d) const void Location::Describe(ODesc* d) const {
{ if ( filename ) {
if ( filename )
{
d->Add(filename); d->Add(filename);
if ( first_line == 0 ) if ( first_line == 0 )
@ -31,22 +27,19 @@ void Location::Describe(ODesc* d) const
d->AddSP(","); d->AddSP(",");
} }
if ( last_line != first_line ) if ( last_line != first_line ) {
{
d->Add("lines "); d->Add("lines ");
d->Add(first_line); d->Add(first_line);
d->Add("-"); d->Add("-");
d->Add(last_line); d->Add(last_line);
} }
else else {
{
d->Add("line "); d->Add("line ");
d->Add(first_line); d->Add(first_line);
} }
} }
bool Location::operator==(const Location& l) const bool Location::operator==(const Location& l) const {
{
if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) ) if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) )
return first_line == l.first_line && last_line == l.last_line; return first_line == l.first_line && last_line == l.last_line;
else else
@ -57,26 +50,21 @@ bool Location::operator==(const Location& l) const
int Obj::suppress_errors = 0; int Obj::suppress_errors = 0;
Obj::~Obj() Obj::~Obj() {
{
if ( notify_plugins ) if ( notify_plugins )
PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this)); PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this));
delete location; delete location;
} }
void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const {
const detail::Location* expr_location) const
{
ODesc d; ODesc d;
DoMsg(&d, msg, obj2, pinpoint_only, expr_location); DoMsg(&d, msg, obj2, pinpoint_only, expr_location);
reporter->Warning("%s", d.Description()); reporter->Warning("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const {
const detail::Location* expr_location) const
{
if ( suppress_errors ) if ( suppress_errors )
return; return;
@ -86,8 +74,7 @@ void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only,
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::BadTag(const char* msg, const char* t1, const char* t2) const void Obj::BadTag(const char* msg, const char* t1, const char* t2) const {
{
char out[512]; char out[512];
if ( t2 ) if ( t2 )
@ -103,8 +90,7 @@ void Obj::BadTag(const char* msg, const char* t1, const char* t2) const
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::Internal(const char* msg) const void Obj::Internal(const char* msg) const {
{
ODesc d; ODesc d;
DoMsg(&d, msg); DoMsg(&d, msg);
auto rcs = render_call_stack(); auto rcs = render_call_stack();
@ -117,18 +103,15 @@ void Obj::Internal(const char* msg) const
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::InternalWarning(const char* msg) const void Obj::InternalWarning(const char* msg) const {
{
ODesc d; ODesc d;
DoMsg(&d, msg); DoMsg(&d, msg);
reporter->InternalWarning("%s", d.Description()); reporter->InternalWarning("%s", d.Description());
reporter->PopLocation(); reporter->PopLocation();
} }
void Obj::AddLocation(ODesc* d) const void Obj::AddLocation(ODesc* d) const {
{ if ( ! location ) {
if ( ! location )
{
d->Add("<no location>"); d->Add("<no location>");
return; return;
} }
@ -136,8 +119,7 @@ void Obj::AddLocation(ODesc* d) const
location->Describe(d); location->Describe(d);
} }
bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) {
{
if ( ! start || ! end ) if ( ! start || ! end )
return false; return false;
@ -150,14 +132,13 @@ bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location*
delete location; delete location;
location = new detail::Location(start->filename, start->first_line, end->last_line, location =
start->first_column, end->last_column); new detail::Location(start->filename, start->first_line, end->last_line, start->first_column, end->last_column);
return true; return true;
} }
void Obj::UpdateLocationEndInfo(const detail::Location& end) void Obj::UpdateLocationEndInfo(const detail::Location& end) {
{
if ( ! location ) if ( ! location )
SetLocationInfo(&end, &end); SetLocationInfo(&end, &end);
@ -166,16 +147,14 @@ void Obj::UpdateLocationEndInfo(const detail::Location& end)
} }
void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only, void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only,
const detail::Location* expr_location) const const detail::Location* expr_location) const {
{
d->SetShort(); d->SetShort();
d->Add(s1); d->Add(s1);
PinPoint(d, obj2, pinpoint_only); PinPoint(d, obj2, pinpoint_only);
const detail::Location* loc2 = nullptr; const detail::Location* loc2 = nullptr;
if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && *obj2->GetLocationInfo() != *GetLocationInfo() )
*obj2->GetLocationInfo() != *GetLocationInfo() )
loc2 = obj2->GetLocationInfo(); loc2 = obj2->GetLocationInfo();
else if ( expr_location ) else if ( expr_location )
loc2 = expr_location; loc2 = expr_location;
@ -183,12 +162,10 @@ void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only,
reporter->PushLocation(GetLocationInfo(), loc2); reporter->PushLocation(GetLocationInfo(), loc2);
} }
void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const {
{
d->Add(" ("); d->Add(" (");
Describe(d); Describe(d);
if ( obj2 && ! pinpoint_only ) if ( obj2 && ! pinpoint_only ) {
{
d->Add(" and "); d->Add(" and ");
obj2->Describe(d); obj2->Describe(d);
} }
@ -196,23 +173,18 @@ void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const
d->Add(")"); d->Add(")");
} }
void Obj::Print() const void Obj::Print() const {
{
static File fstderr(stderr); static File fstderr(stderr);
ODesc d(DESC_READABLE, &fstderr); ODesc d(DESC_READABLE, &fstderr);
Describe(&d); Describe(&d);
d.Add("\n"); d.Add("\n");
} }
void bad_ref(int type) void bad_ref(int type) {
{
reporter->InternalError("bad reference count [%d]", type); reporter->InternalError("bad reference count [%d]", type);
abort(); abort();
} }
void obj_delete_func(void* v) void obj_delete_func(void* v) { Unref((Obj*)v); }
{
Unref((Obj*)v);
}
} // namespace zeek } // namespace zeek

View file

@ -6,22 +6,16 @@
#include <climits> #include <climits>
namespace zeek namespace zeek {
{
class ODesc; class ODesc;
namespace detail namespace detail {
{
class Location final class Location final {
{
public: public:
constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept
: filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), : filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), last_column(col_l) {}
last_column(col_l)
{
}
Location() = default; Location() = default;
@ -48,24 +42,18 @@ extern Location start_location;
extern Location end_location; extern Location end_location;
// Used by parser to set the above. // Used by parser to set the above.
inline void set_location(const Location loc) inline void set_location(const Location loc) { start_location = end_location = loc; }
{
start_location = end_location = loc;
}
inline void set_location(const Location start, const Location end) inline void set_location(const Location start, const Location end) {
{
start_location = start; start_location = start;
end_location = end; end_location = end;
} }
} // namespace detail } // namespace detail
class Obj class Obj {
{
public: public:
Obj() Obj() {
{
// A bit of a hack. We'd like to associate location // A bit of a hack. We'd like to associate location
// information with every object created when parsing, // information with every object created when parsing,
// since for them, the location is generally well-defined. // since for them, the location is generally well-defined.
@ -114,10 +102,7 @@ public:
void AddLocation(ODesc* d) const; void AddLocation(ODesc* d) const;
// Get location info for debugging. // Get location info for debugging.
virtual const detail::Location* GetLocationInfo() const virtual const detail::Location* GetLocationInfo() const { return location ? location : &detail::no_location; }
{
return location ? location : &detail::no_location;
}
virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); } virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); }
@ -135,8 +120,7 @@ public:
// Helper class to temporarily suppress errors // Helper class to temporarily suppress errors
// as long as there exist any instances. // as long as there exist any instances.
class SuppressErrors class SuppressErrors {
{
public: public:
SuppressErrors() { ++Obj::suppress_errors; } SuppressErrors() { ++Obj::suppress_errors; }
~SuppressErrors() { --Obj::suppress_errors; } ~SuppressErrors() { --Obj::suppress_errors; }
@ -167,25 +151,19 @@ private:
// Sometimes useful when dealing with Obj subclasses that have their // Sometimes useful when dealing with Obj subclasses that have their
// own (protected) versions of Error. // own (protected) versions of Error.
inline void Error(const Obj* o, const char* msg) inline void Error(const Obj* o, const char* msg) { o->Error(msg); }
{
o->Error(msg);
}
[[noreturn]] extern void bad_ref(int type); [[noreturn]] extern void bad_ref(int type);
inline void Ref(Obj* o) inline void Ref(Obj* o) {
{
if ( ++(o->ref_cnt) <= 1 ) if ( ++(o->ref_cnt) <= 1 )
bad_ref(0); bad_ref(0);
if ( o->ref_cnt == INT_MAX ) if ( o->ref_cnt == INT_MAX )
bad_ref(1); bad_ref(1);
} }
inline void Unref(Obj* o) inline void Unref(Obj* o) {
{ if ( o && --o->ref_cnt <= 0 ) {
if ( o && --o->ref_cnt <= 0 )
{
if ( o->ref_cnt < 0 ) if ( o->ref_cnt < 0 )
bad_ref(2); bad_ref(2);
delete o; delete o;

View file

@ -23,14 +23,12 @@
#include "zeek/probabilistic/BloomFilter.h" #include "zeek/probabilistic/BloomFilter.h"
#include "zeek/probabilistic/CardinalityCounter.h" #include "zeek/probabilistic/CardinalityCounter.h"
namespace zeek namespace zeek {
{
// Helper to retrieve a broker value out of a broker::vector at a specified // Helper to retrieve a broker value out of a broker::vector at a specified
// index, and casted to the expected destination type. // index, and casted to the expected destination type.
template<typename S, typename V, typename D> template<typename S, typename V, typename D>
inline bool get_vector_idx(const V& v, unsigned int i, D* dst) inline bool get_vector_idx(const V& v, unsigned int i, D* dst) {
{
if ( i >= v.size() ) if ( i >= v.size() )
return false; return false;
@ -42,33 +40,28 @@ inline bool get_vector_idx(const V& v, unsigned int i, D* dst)
return true; return true;
} }
OpaqueMgr* OpaqueMgr::mgr() OpaqueMgr* OpaqueMgr::mgr() {
{
static OpaqueMgr mgr; static OpaqueMgr mgr;
return &mgr; return &mgr;
} }
OpaqueVal::OpaqueVal(OpaqueTypePtr t) : Val(std::move(t)) {} OpaqueVal::OpaqueVal(OpaqueTypePtr t) : Val(std::move(t)) {}
const std::string& OpaqueMgr::TypeID(const OpaqueVal* v) const const std::string& OpaqueMgr::TypeID(const OpaqueVal* v) const {
{
auto x = _types.find(v->OpaqueName()); auto x = _types.find(v->OpaqueName());
if ( x == _types.end() ) if ( x == _types.end() )
reporter->InternalError("OpaqueMgr::TypeID: opaque type %s not registered", reporter->InternalError("OpaqueMgr::TypeID: opaque type %s not registered", v->OpaqueName());
v->OpaqueName());
return x->first; return x->first;
} }
OpaqueValPtr OpaqueMgr::Instantiate(const std::string& id) const OpaqueValPtr OpaqueMgr::Instantiate(const std::string& id) const {
{
auto x = _types.find(id); auto x = _types.find(id);
return x != _types.end() ? (*x->second)() : nullptr; return x != _types.end() ? (*x->second)() : nullptr;
} }
broker::expected<broker::data> OpaqueVal::Serialize() const broker::expected<broker::data> OpaqueVal::Serialize() const {
{
auto type = OpaqueMgr::mgr()->TypeID(this); auto type = OpaqueMgr::mgr()->TypeID(this);
auto d = DoSerialize(); auto d = DoSerialize();
@ -78,8 +71,7 @@ broker::expected<broker::data> OpaqueVal::Serialize() const
return {broker::vector{std::move(type), std::move(*d)}}; return {broker::vector{std::move(type), std::move(*d)}};
} }
OpaqueValPtr OpaqueVal::Unserialize(const broker::data& data) OpaqueValPtr OpaqueVal::Unserialize(const broker::data& data) {
{
auto v = broker::get_if<broker::vector>(&data); auto v = broker::get_if<broker::vector>(&data);
if ( ! (v && v->size() == 2) ) if ( ! (v && v->size() == 2) )
@ -99,13 +91,11 @@ OpaqueValPtr OpaqueVal::Unserialize(const broker::data& data)
return val; return val;
} }
broker::expected<broker::data> OpaqueVal::SerializeType(const TypePtr& t) broker::expected<broker::data> OpaqueVal::SerializeType(const TypePtr& t) {
{
if ( t->InternalType() == TYPE_INTERNAL_ERROR ) if ( t->InternalType() == TYPE_INTERNAL_ERROR )
return broker::ec::invalid_data; return broker::ec::invalid_data;
if ( t->InternalType() == TYPE_INTERNAL_OTHER ) if ( t->InternalType() == TYPE_INTERNAL_OTHER ) {
{
// Serialize by name. // Serialize by name.
assert(t->GetName().size()); assert(t->GetName().size());
return {broker::vector{true, t->GetName()}}; return {broker::vector{true, t->GetName()}};
@ -115,8 +105,7 @@ broker::expected<broker::data> OpaqueVal::SerializeType(const TypePtr& t)
return {broker::vector{false, static_cast<uint64_t>(t->Tag())}}; return {broker::vector{false, static_cast<uint64_t>(t->Tag())}};
} }
TypePtr OpaqueVal::UnserializeType(const broker::data& data) TypePtr OpaqueVal::UnserializeType(const broker::data& data) {
{
auto v = broker::get_if<broker::vector>(&data); auto v = broker::get_if<broker::vector>(&data);
if ( ! (v && v->size() == 2) ) if ( ! (v && v->size() == 2) )
return nullptr; return nullptr;
@ -125,8 +114,7 @@ TypePtr OpaqueVal::UnserializeType(const broker::data& data)
if ( ! by_name ) if ( ! by_name )
return nullptr; return nullptr;
if ( *by_name ) if ( *by_name ) {
{
auto name = broker::get_if<std::string>(&(*v)[1]); auto name = broker::get_if<std::string>(&(*v)[1]);
if ( ! name ) if ( ! name )
return nullptr; return nullptr;
@ -148,8 +136,7 @@ TypePtr OpaqueVal::UnserializeType(const broker::data& data)
return base_type(static_cast<TypeTag>(*tag)); return base_type(static_cast<TypeTag>(*tag));
} }
ValPtr OpaqueVal::DoClone(CloneState* state) ValPtr OpaqueVal::DoClone(CloneState* state) {
{
auto d = OpaqueVal::Serialize(); auto d = OpaqueVal::Serialize();
if ( ! d ) if ( ! d )
return nullptr; return nullptr;
@ -158,23 +145,13 @@ ValPtr OpaqueVal::DoClone(CloneState* state)
return state->NewClone(this, std::move(rval)); return state->NewClone(this, std::move(rval));
} }
void OpaqueVal::ValDescribe(ODesc* d) const void OpaqueVal::ValDescribe(ODesc* d) const { d->Add(util::fmt("<opaque of %s>", OpaqueName())); }
{
d->Add(util::fmt("<opaque of %s>", OpaqueName()));
}
void OpaqueVal::ValDescribeReST(ODesc* d) const void OpaqueVal::ValDescribeReST(ODesc* d) const { d->Add(util::fmt("<opaque of %s>", OpaqueName())); }
{
d->Add(util::fmt("<opaque of %s>", OpaqueName()));
}
bool HashVal::IsValid() const bool HashVal::IsValid() const { return valid; }
{
return valid;
}
bool HashVal::Init() bool HashVal::Init() {
{
if ( valid ) if ( valid )
return false; return false;
@ -182,8 +159,7 @@ bool HashVal::Init()
return valid; return valid;
} }
StringValPtr HashVal::Get() StringValPtr HashVal::Get() {
{
if ( ! valid ) if ( ! valid )
return val_mgr->EmptyString(); return val_mgr->EmptyString();
@ -192,8 +168,7 @@ StringValPtr HashVal::Get()
return result; return result;
} }
bool HashVal::Feed(const void* data, size_t size) bool HashVal::Feed(const void* data, size_t size) {
{
if ( valid ) if ( valid )
return DoFeed(data, size); return DoFeed(data, size);
@ -201,68 +176,50 @@ bool HashVal::Feed(const void* data, size_t size)
return false; return false;
} }
bool HashVal::DoInit() bool HashVal::DoInit() {
{
assert(! "missing implementation of DoInit()"); assert(! "missing implementation of DoInit()");
return false; return false;
} }
bool HashVal::DoFeed(const void*, size_t) bool HashVal::DoFeed(const void*, size_t) {
{
assert(! "missing implementation of DoFeed()"); assert(! "missing implementation of DoFeed()");
return false; return false;
} }
StringValPtr HashVal::DoGet() StringValPtr HashVal::DoGet() {
{
assert(! "missing implementation of DoGet()"); assert(! "missing implementation of DoGet()");
return val_mgr->EmptyString(); return val_mgr->EmptyString();
} }
HashVal::HashVal(OpaqueTypePtr t) : OpaqueVal(std::move(t)) HashVal::HashVal(OpaqueTypePtr t) : OpaqueVal(std::move(t)) { valid = false; }
{
valid = false;
}
MD5Val::MD5Val() : HashVal(md5_type) MD5Val::MD5Val() : HashVal(md5_type) { memset(&ctx, 0, sizeof(ctx)); }
{
memset(&ctx, 0, sizeof(ctx));
}
MD5Val::~MD5Val() MD5Val::~MD5Val() {
{
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
if ( IsValid() ) if ( IsValid() )
EVP_MD_CTX_free(ctx); EVP_MD_CTX_free(ctx);
#endif #endif
} }
void HashVal::digest_one(EVP_MD_CTX* h, const Val* v) void HashVal::digest_one(EVP_MD_CTX* h, const Val* v) {
{ if ( v->GetType()->Tag() == TYPE_STRING ) {
if ( v->GetType()->Tag() == TYPE_STRING )
{
const String* str = v->AsString(); const String* str = v->AsString();
detail::hash_update(h, str->Bytes(), str->Len()); detail::hash_update(h, str->Bytes(), str->Len());
} }
else else {
{
ODesc d(DESC_BINARY); ODesc d(DESC_BINARY);
v->Describe(&d); v->Describe(&d);
detail::hash_update(h, (const u_char*)d.Bytes(), d.Len()); detail::hash_update(h, (const u_char*)d.Bytes(), d.Len());
} }
} }
void HashVal::digest_one(EVP_MD_CTX* h, const ValPtr& v) void HashVal::digest_one(EVP_MD_CTX* h, const ValPtr& v) { digest_one(h, v.get()); }
{
digest_one(h, v.get());
}
ValPtr MD5Val::DoClone(CloneState* state) ValPtr MD5Val::DoClone(CloneState* state) {
{
auto out = make_intrusive<MD5Val>(); auto out = make_intrusive<MD5Val>();
if ( IsValid() ) if ( IsValid() ) {
{
if ( ! out->Init() ) if ( ! out->Init() )
return nullptr; return nullptr;
@ -276,8 +233,7 @@ ValPtr MD5Val::DoClone(CloneState* state)
return state->NewClone(this, std::move(out)); return state->NewClone(this, std::move(out));
} }
bool MD5Val::DoInit() bool MD5Val::DoInit() {
{
assert(! IsValid()); assert(! IsValid());
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
ctx = detail::hash_init(detail::Hash_MD5); ctx = detail::hash_init(detail::Hash_MD5);
@ -287,8 +243,7 @@ bool MD5Val::DoInit()
return true; return true;
} }
bool MD5Val::DoFeed(const void* data, size_t size) bool MD5Val::DoFeed(const void* data, size_t size) {
{
if ( ! IsValid() ) if ( ! IsValid() )
return false; return false;
@ -300,8 +255,7 @@ bool MD5Val::DoFeed(const void* data, size_t size)
return true; return true;
} }
StringValPtr MD5Val::DoGet() StringValPtr MD5Val::DoGet() {
{
if ( ! IsValid() ) if ( ! IsValid() )
return val_mgr->EmptyString(); return val_mgr->EmptyString();
@ -316,8 +270,7 @@ StringValPtr MD5Val::DoGet()
IMPLEMENT_OPAQUE_VALUE(MD5Val) IMPLEMENT_OPAQUE_VALUE(MD5Val)
broker::expected<broker::data> MD5Val::DoSerialize() const broker::expected<broker::data> MD5Val::DoSerialize() const {
{
if ( ! IsValid() ) if ( ! IsValid() )
return {broker::vector{false}}; return {broker::vector{false}};
@ -332,8 +285,7 @@ broker::expected<broker::data> MD5Val::DoSerialize() const
return {std::move(d)}; return {std::move(d)};
} }
bool MD5Val::DoUnserialize(const broker::data& data) bool MD5Val::DoUnserialize(const broker::data& data) {
{
auto d = broker::get_if<broker::vector>(&data); auto d = broker::get_if<broker::vector>(&data);
if ( ! d ) if ( ! d )
return false; return false;
@ -342,8 +294,7 @@ bool MD5Val::DoUnserialize(const broker::data& data)
if ( ! valid ) if ( ! valid )
return false; return false;
if ( ! *valid ) if ( ! *valid ) {
{
assert(! IsValid()); // default set by ctor assert(! IsValid()); // default set by ctor
return true; return true;
} }
@ -372,25 +323,19 @@ bool MD5Val::DoUnserialize(const broker::data& data)
return true; return true;
} }
SHA1Val::SHA1Val() : HashVal(sha1_type) SHA1Val::SHA1Val() : HashVal(sha1_type) { memset(&ctx, 0, sizeof(ctx)); }
{
memset(&ctx, 0, sizeof(ctx));
}
SHA1Val::~SHA1Val() SHA1Val::~SHA1Val() {
{
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
if ( IsValid() ) if ( IsValid() )
EVP_MD_CTX_free(ctx); EVP_MD_CTX_free(ctx);
#endif #endif
} }
ValPtr SHA1Val::DoClone(CloneState* state) ValPtr SHA1Val::DoClone(CloneState* state) {
{
auto out = make_intrusive<SHA1Val>(); auto out = make_intrusive<SHA1Val>();
if ( IsValid() ) if ( IsValid() ) {
{
if ( ! out->Init() ) if ( ! out->Init() )
return nullptr; return nullptr;
@ -404,8 +349,7 @@ ValPtr SHA1Val::DoClone(CloneState* state)
return state->NewClone(this, std::move(out)); return state->NewClone(this, std::move(out));
} }
bool SHA1Val::DoInit() bool SHA1Val::DoInit() {
{
assert(! IsValid()); assert(! IsValid());
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
ctx = detail::hash_init(detail::Hash_SHA1); ctx = detail::hash_init(detail::Hash_SHA1);
@ -415,8 +359,7 @@ bool SHA1Val::DoInit()
return true; return true;
} }
bool SHA1Val::DoFeed(const void* data, size_t size) bool SHA1Val::DoFeed(const void* data, size_t size) {
{
if ( ! IsValid() ) if ( ! IsValid() )
return false; return false;
@ -428,8 +371,7 @@ bool SHA1Val::DoFeed(const void* data, size_t size)
return true; return true;
} }
StringValPtr SHA1Val::DoGet() StringValPtr SHA1Val::DoGet() {
{
if ( ! IsValid() ) if ( ! IsValid() )
return val_mgr->EmptyString(); return val_mgr->EmptyString();
@ -444,8 +386,7 @@ StringValPtr SHA1Val::DoGet()
IMPLEMENT_OPAQUE_VALUE(SHA1Val) IMPLEMENT_OPAQUE_VALUE(SHA1Val)
broker::expected<broker::data> SHA1Val::DoSerialize() const broker::expected<broker::data> SHA1Val::DoSerialize() const {
{
if ( ! IsValid() ) if ( ! IsValid() )
return {broker::vector{false}}; return {broker::vector{false}};
@ -461,8 +402,7 @@ broker::expected<broker::data> SHA1Val::DoSerialize() const
return {std::move(d)}; return {std::move(d)};
} }
bool SHA1Val::DoUnserialize(const broker::data& data) bool SHA1Val::DoUnserialize(const broker::data& data) {
{
auto d = broker::get_if<broker::vector>(&data); auto d = broker::get_if<broker::vector>(&data);
if ( ! d ) if ( ! d )
return false; return false;
@ -471,8 +411,7 @@ bool SHA1Val::DoUnserialize(const broker::data& data)
if ( ! valid ) if ( ! valid )
return false; return false;
if ( ! *valid ) if ( ! *valid ) {
{
assert(! IsValid()); // default set by ctor assert(! IsValid()); // default set by ctor
return true; return true;
} }
@ -501,25 +440,19 @@ bool SHA1Val::DoUnserialize(const broker::data& data)
return true; return true;
} }
SHA256Val::SHA256Val() : HashVal(sha256_type) SHA256Val::SHA256Val() : HashVal(sha256_type) { memset(&ctx, 0, sizeof(ctx)); }
{
memset(&ctx, 0, sizeof(ctx));
}
SHA256Val::~SHA256Val() SHA256Val::~SHA256Val() {
{
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
if ( IsValid() ) if ( IsValid() )
EVP_MD_CTX_free(ctx); EVP_MD_CTX_free(ctx);
#endif #endif
} }
ValPtr SHA256Val::DoClone(CloneState* state) ValPtr SHA256Val::DoClone(CloneState* state) {
{
auto out = make_intrusive<SHA256Val>(); auto out = make_intrusive<SHA256Val>();
if ( IsValid() ) if ( IsValid() ) {
{
if ( ! out->Init() ) if ( ! out->Init() )
return nullptr; return nullptr;
@ -533,8 +466,7 @@ ValPtr SHA256Val::DoClone(CloneState* state)
return state->NewClone(this, std::move(out)); return state->NewClone(this, std::move(out));
} }
bool SHA256Val::DoInit() bool SHA256Val::DoInit() {
{
assert(! IsValid()); assert(! IsValid());
#if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER)
ctx = detail::hash_init(detail::Hash_SHA256); ctx = detail::hash_init(detail::Hash_SHA256);
@ -544,8 +476,7 @@ bool SHA256Val::DoInit()
return true; return true;
} }
bool SHA256Val::DoFeed(const void* data, size_t size) bool SHA256Val::DoFeed(const void* data, size_t size) {
{
if ( ! IsValid() ) if ( ! IsValid() )
return false; return false;
@ -557,8 +488,7 @@ bool SHA256Val::DoFeed(const void* data, size_t size)
return true; return true;
} }
StringValPtr SHA256Val::DoGet() StringValPtr SHA256Val::DoGet() {
{
if ( ! IsValid() ) if ( ! IsValid() )
return val_mgr->EmptyString(); return val_mgr->EmptyString();
@ -573,8 +503,7 @@ StringValPtr SHA256Val::DoGet()
IMPLEMENT_OPAQUE_VALUE(SHA256Val) IMPLEMENT_OPAQUE_VALUE(SHA256Val)
broker::expected<broker::data> SHA256Val::DoSerialize() const broker::expected<broker::data> SHA256Val::DoSerialize() const {
{
if ( ! IsValid() ) if ( ! IsValid() )
return {broker::vector{false}}; return {broker::vector{false}};
@ -590,8 +519,7 @@ broker::expected<broker::data> SHA256Val::DoSerialize() const
return {std::move(d)}; return {std::move(d)};
} }
bool SHA256Val::DoUnserialize(const broker::data& data) bool SHA256Val::DoUnserialize(const broker::data& data) {
{
auto d = broker::get_if<broker::vector>(&data); auto d = broker::get_if<broker::vector>(&data);
if ( ! d ) if ( ! d )
return false; return false;
@ -600,8 +528,7 @@ bool SHA256Val::DoUnserialize(const broker::data& data)
if ( ! valid ) if ( ! valid )
return false; return false;
if ( ! *valid ) if ( ! *valid ) {
{
assert(! IsValid()); // default set by ctor assert(! IsValid()); // default set by ctor
return true; return true;
} }
@ -632,23 +559,19 @@ bool SHA256Val::DoUnserialize(const broker::data& data)
EntropyVal::EntropyVal() : OpaqueVal(entropy_type) {} EntropyVal::EntropyVal() : OpaqueVal(entropy_type) {}
bool EntropyVal::Feed(const void* data, size_t size) bool EntropyVal::Feed(const void* data, size_t size) {
{
state.add(data, size); state.add(data, size);
return true; return true;
} }
bool EntropyVal::Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, bool EntropyVal::Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, double* r_scc) {
double* r_scc)
{
state.end(r_ent, r_chisq, r_mean, r_montepicalc, r_scc); state.end(r_ent, r_chisq, r_mean, r_montepicalc, r_scc);
return true; return true;
} }
IMPLEMENT_OPAQUE_VALUE(EntropyVal) IMPLEMENT_OPAQUE_VALUE(EntropyVal)
broker::expected<broker::data> EntropyVal::DoSerialize() const broker::expected<broker::data> EntropyVal::DoSerialize() const {
{
broker::vector d = { broker::vector d = {
static_cast<uint64_t>(state.totalc), static_cast<uint64_t>(state.mp), static_cast<uint64_t>(state.totalc), static_cast<uint64_t>(state.mp),
static_cast<uint64_t>(state.sccfirst), static_cast<uint64_t>(state.inmont), static_cast<uint64_t>(state.sccfirst), static_cast<uint64_t>(state.inmont),
@ -670,8 +593,7 @@ broker::expected<broker::data> EntropyVal::DoSerialize() const
return {std::move(d)}; return {std::move(d)};
} }
bool EntropyVal::DoUnserialize(const broker::data& data) bool EntropyVal::DoUnserialize(const broker::data& data) {
{
auto d = broker::get_if<broker::vector>(&data); auto d = broker::get_if<broker::vector>(&data);
if ( ! d ) if ( ! d )
return false; return false;
@ -705,14 +627,12 @@ bool EntropyVal::DoUnserialize(const broker::data& data)
if ( ! get_vector_idx<uint64_t>(*d, 13, &state.scct3) ) if ( ! get_vector_idx<uint64_t>(*d, 13, &state.scct3) )
return false; return false;
for ( int i = 0; i < 256; ++i ) for ( int i = 0; i < 256; ++i ) {
{
if ( ! get_vector_idx<uint64_t>(*d, 14 + i, &state.ccount[i]) ) if ( ! get_vector_idx<uint64_t>(*d, 14 + i, &state.ccount[i]) )
return false; return false;
} }
for ( int i = 0; i < RT_MONTEN; ++i ) for ( int i = 0; i < RT_MONTEN; ++i ) {
{
if ( ! get_vector_idx<uint64_t>(*d, 14 + 256 + i, &state.monte[i]) ) if ( ! get_vector_idx<uint64_t>(*d, 14 + 256 + i, &state.monte[i]) )
return false; return false;
} }
@ -720,22 +640,18 @@ bool EntropyVal::DoUnserialize(const broker::data& data)
return true; return true;
} }
BloomFilterVal::BloomFilterVal() : OpaqueVal(bloomfilter_type) BloomFilterVal::BloomFilterVal() : OpaqueVal(bloomfilter_type) {
{
hash = nullptr; hash = nullptr;
bloom_filter = nullptr; bloom_filter = nullptr;
} }
BloomFilterVal::BloomFilterVal(probabilistic::BloomFilter* bf) : OpaqueVal(bloomfilter_type) BloomFilterVal::BloomFilterVal(probabilistic::BloomFilter* bf) : OpaqueVal(bloomfilter_type) {
{
hash = nullptr; hash = nullptr;
bloom_filter = bf; bloom_filter = bf;
} }
ValPtr BloomFilterVal::DoClone(CloneState* state) ValPtr BloomFilterVal::DoClone(CloneState* state) {
{ if ( bloom_filter ) {
if ( bloom_filter )
{
auto bf = make_intrusive<BloomFilterVal>(bloom_filter->Clone()); auto bf = make_intrusive<BloomFilterVal>(bloom_filter->Clone());
assert(type); assert(type);
bf->Typify(type); bf->Typify(type);
@ -745,8 +661,7 @@ ValPtr BloomFilterVal::DoClone(CloneState* state)
return state->NewClone(this, make_intrusive<BloomFilterVal>()); return state->NewClone(this, make_intrusive<BloomFilterVal>());
} }
bool BloomFilterVal::Typify(TypePtr arg_type) bool BloomFilterVal::Typify(TypePtr arg_type) {
{
if ( type ) if ( type )
return false; return false;
@ -759,61 +674,45 @@ bool BloomFilterVal::Typify(TypePtr arg_type)
return true; return true;
} }
void BloomFilterVal::Add(const Val* val) void BloomFilterVal::Add(const Val* val) {
{
auto key = hash->MakeHashKey(*val, true); auto key = hash->MakeHashKey(*val, true);
bloom_filter->Add(key.get()); bloom_filter->Add(key.get());
} }
bool BloomFilterVal::Decrement(const Val* val) bool BloomFilterVal::Decrement(const Val* val) {
{
auto key = hash->MakeHashKey(*val, true); auto key = hash->MakeHashKey(*val, true);
return bloom_filter->Decrement(key.get()); return bloom_filter->Decrement(key.get());
} }
size_t BloomFilterVal::Count(const Val* val) const size_t BloomFilterVal::Count(const Val* val) const {
{
auto key = hash->MakeHashKey(*val, true); auto key = hash->MakeHashKey(*val, true);
size_t cnt = bloom_filter->Count(key.get()); size_t cnt = bloom_filter->Count(key.get());
return cnt; return cnt;
} }
void BloomFilterVal::Clear() void BloomFilterVal::Clear() { bloom_filter->Clear(); }
{
bloom_filter->Clear();
}
bool BloomFilterVal::Empty() const bool BloomFilterVal::Empty() const { return bloom_filter->Empty(); }
{
return bloom_filter->Empty();
}
std::string BloomFilterVal::InternalState() const std::string BloomFilterVal::InternalState() const { return bloom_filter->InternalState(); }
{
return bloom_filter->InternalState();
}
BloomFilterValPtr BloomFilterVal::Merge(const BloomFilterVal* x, const BloomFilterVal* y) BloomFilterValPtr BloomFilterVal::Merge(const BloomFilterVal* x, const BloomFilterVal* y) {
{
if ( x->Type() && // any one 0 is ok here if ( x->Type() && // any one 0 is ok here
y->Type() && ! same_type(x->Type(), y->Type()) ) y->Type() && ! same_type(x->Type(), y->Type()) ) {
{
reporter->Error("cannot merge Bloom filters with different types"); reporter->Error("cannot merge Bloom filters with different types");
return nullptr; return nullptr;
} }
auto final_type = x->Type() ? x->Type() : y->Type(); auto final_type = x->Type() ? x->Type() : y->Type();
if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) {
{
reporter->Error("cannot merge different Bloom filter types"); reporter->Error("cannot merge different Bloom filter types");
return nullptr; return nullptr;
} }
probabilistic::BloomFilter* copy = x->bloom_filter->Clone(); probabilistic::BloomFilter* copy = x->bloom_filter->Clone();
if ( ! copy->Merge(y->bloom_filter) ) if ( ! copy->Merge(y->bloom_filter) ) {
{
delete copy; delete copy;
reporter->Error("failed to merge Bloom filter"); reporter->Error("failed to merge Bloom filter");
return nullptr; return nullptr;
@ -821,8 +720,7 @@ BloomFilterValPtr BloomFilterVal::Merge(const BloomFilterVal* x, const BloomFilt
auto merged = make_intrusive<BloomFilterVal>(copy); auto merged = make_intrusive<BloomFilterVal>(copy);
if ( final_type && ! merged->Typify(final_type) ) if ( final_type && ! merged->Typify(final_type) ) {
{
reporter->Error("failed to set type on merged Bloom filter"); reporter->Error("failed to set type on merged Bloom filter");
return nullptr; return nullptr;
} }
@ -830,25 +728,21 @@ BloomFilterValPtr BloomFilterVal::Merge(const BloomFilterVal* x, const BloomFilt
return merged; return merged;
} }
BloomFilterValPtr BloomFilterVal::Intersect(const BloomFilterVal* x, const BloomFilterVal* y) BloomFilterValPtr BloomFilterVal::Intersect(const BloomFilterVal* x, const BloomFilterVal* y) {
{
if ( x->Type() && // any one 0 is ok here if ( x->Type() && // any one 0 is ok here
y->Type() && ! same_type(x->Type(), y->Type()) ) y->Type() && ! same_type(x->Type(), y->Type()) ) {
{
reporter->Error("cannot merge Bloom filters with different types"); reporter->Error("cannot merge Bloom filters with different types");
return nullptr; return nullptr;
} }
if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) {
{
reporter->Error("cannot intersect different Bloom filter types"); reporter->Error("cannot intersect different Bloom filter types");
return nullptr; return nullptr;
} }
auto intersected_bf = x->bloom_filter->Intersect(y->bloom_filter); auto intersected_bf = x->bloom_filter->Intersect(y->bloom_filter);
if ( ! intersected_bf ) if ( ! intersected_bf ) {
{
reporter->Error("failed to intersect Bloom filter"); reporter->Error("failed to intersect Bloom filter");
return nullptr; return nullptr;
} }
@ -857,8 +751,7 @@ BloomFilterValPtr BloomFilterVal::Intersect(const BloomFilterVal* x, const Bloom
auto intersected = make_intrusive<BloomFilterVal>(intersected_bf); auto intersected = make_intrusive<BloomFilterVal>(intersected_bf);
if ( final_type && ! intersected->Typify(final_type) ) if ( final_type && ! intersected->Typify(final_type) ) {
{
reporter->Error("Failed to set type on intersected bloom filter"); reporter->Error("Failed to set type on intersected bloom filter");
return nullptr; return nullptr;
} }
@ -866,20 +759,17 @@ BloomFilterValPtr BloomFilterVal::Intersect(const BloomFilterVal* x, const Bloom
return intersected; return intersected;
} }
BloomFilterVal::~BloomFilterVal() BloomFilterVal::~BloomFilterVal() {
{
delete hash; delete hash;
delete bloom_filter; delete bloom_filter;
} }
IMPLEMENT_OPAQUE_VALUE(BloomFilterVal) IMPLEMENT_OPAQUE_VALUE(BloomFilterVal)
broker::expected<broker::data> BloomFilterVal::DoSerialize() const broker::expected<broker::data> BloomFilterVal::DoSerialize() const {
{
broker::vector d; broker::vector d;
if ( type ) if ( type ) {
{
auto t = SerializeType(type); auto t = SerializeType(type);
if ( ! t ) if ( ! t )
return broker::ec::invalid_data; return broker::ec::invalid_data;
@ -897,16 +787,14 @@ broker::expected<broker::data> BloomFilterVal::DoSerialize() const
return {std::move(d)}; return {std::move(d)};
} }
bool BloomFilterVal::DoUnserialize(const broker::data& data) bool BloomFilterVal::DoUnserialize(const broker::data& data) {
{
auto v = broker::get_if<broker::vector>(&data); auto v = broker::get_if<broker::vector>(&data);
if ( ! (v && v->size() == 2) ) if ( ! (v && v->size() == 2) )
return false; return false;
auto no_type = broker::get_if<broker::none>(&(*v)[0]); auto no_type = broker::get_if<broker::none>(&(*v)[0]);
if ( ! no_type ) if ( ! no_type ) {
{
auto t = UnserializeType((*v)[0]); auto t = UnserializeType((*v)[0]);
if ( ! (t && Typify(std::move(t))) ) if ( ! (t && Typify(std::move(t))) )
@ -921,33 +809,26 @@ bool BloomFilterVal::DoUnserialize(const broker::data& data)
return true; return true;
} }
CardinalityVal::CardinalityVal() : OpaqueVal(cardinality_type) CardinalityVal::CardinalityVal() : OpaqueVal(cardinality_type) {
{
c = nullptr; c = nullptr;
hash = nullptr; hash = nullptr;
} }
CardinalityVal::CardinalityVal(probabilistic::detail::CardinalityCounter* arg_c) CardinalityVal::CardinalityVal(probabilistic::detail::CardinalityCounter* arg_c) : OpaqueVal(cardinality_type) {
: OpaqueVal(cardinality_type)
{
c = arg_c; c = arg_c;
hash = nullptr; hash = nullptr;
} }
CardinalityVal::~CardinalityVal() CardinalityVal::~CardinalityVal() {
{
delete c; delete c;
delete hash; delete hash;
} }
ValPtr CardinalityVal::DoClone(CloneState* state) ValPtr CardinalityVal::DoClone(CloneState* state) {
{ return state->NewClone(this, make_intrusive<CardinalityVal>(new probabilistic::detail::CardinalityCounter(*c)));
return state->NewClone(
this, make_intrusive<CardinalityVal>(new probabilistic::detail::CardinalityCounter(*c)));
} }
bool CardinalityVal::Typify(TypePtr arg_type) bool CardinalityVal::Typify(TypePtr arg_type) {
{
if ( type ) if ( type )
return false; return false;
@ -960,20 +841,17 @@ bool CardinalityVal::Typify(TypePtr arg_type)
return true; return true;
} }
void CardinalityVal::Add(const Val* val) void CardinalityVal::Add(const Val* val) {
{
auto key = hash->MakeHashKey(*val, true); auto key = hash->MakeHashKey(*val, true);
c->AddElement(key->Hash()); c->AddElement(key->Hash());
} }
IMPLEMENT_OPAQUE_VALUE(CardinalityVal) IMPLEMENT_OPAQUE_VALUE(CardinalityVal)
broker::expected<broker::data> CardinalityVal::DoSerialize() const broker::expected<broker::data> CardinalityVal::DoSerialize() const {
{
broker::vector d; broker::vector d;
if ( type ) if ( type ) {
{
auto t = SerializeType(type); auto t = SerializeType(type);
if ( ! t ) if ( ! t )
return broker::ec::invalid_data; return broker::ec::invalid_data;
@ -991,16 +869,14 @@ broker::expected<broker::data> CardinalityVal::DoSerialize() const
return {std::move(d)}; return {std::move(d)};
} }
bool CardinalityVal::DoUnserialize(const broker::data& data) bool CardinalityVal::DoUnserialize(const broker::data& data) {
{
auto v = broker::get_if<broker::vector>(&data); auto v = broker::get_if<broker::vector>(&data);
if ( ! (v && v->size() == 2) ) if ( ! (v && v->size() == 2) )
return false; return false;
auto no_type = broker::get_if<broker::none>(&(*v)[0]); auto no_type = broker::get_if<broker::none>(&(*v)[0]);
if ( ! no_type ) if ( ! no_type ) {
{
auto t = UnserializeType((*v)[0]); auto t = UnserializeType((*v)[0]);
if ( ! (t && Typify(std::move(t))) ) if ( ! (t && Typify(std::move(t))) )
@ -1015,13 +891,11 @@ bool CardinalityVal::DoUnserialize(const broker::data& data)
return true; return true;
} }
ParaglobVal::ParaglobVal(std::unique_ptr<paraglob::Paraglob> p) : OpaqueVal(paraglob_type) ParaglobVal::ParaglobVal(std::unique_ptr<paraglob::Paraglob> p) : OpaqueVal(paraglob_type) {
{
this->internal_paraglob = std::move(p); this->internal_paraglob = std::move(p);
} }
VectorValPtr ParaglobVal::Get(StringVal*& pattern) VectorValPtr ParaglobVal::Get(StringVal*& pattern) {
{
auto rval = make_intrusive<VectorVal>(id::string_vec); auto rval = make_intrusive<VectorVal>(id::string_vec);
std::string string_pattern(reinterpret_cast<const char*>(pattern->Bytes()), pattern->Len()); std::string string_pattern(reinterpret_cast<const char*>(pattern->Bytes()), pattern->Len());
@ -1032,15 +906,13 @@ VectorValPtr ParaglobVal::Get(StringVal*& pattern)
return rval; return rval;
} }
bool ParaglobVal::operator==(const ParaglobVal& other) const bool ParaglobVal::operator==(const ParaglobVal& other) const {
{
return *(this->internal_paraglob) == *(other.internal_paraglob); return *(this->internal_paraglob) == *(other.internal_paraglob);
} }
IMPLEMENT_OPAQUE_VALUE(ParaglobVal) IMPLEMENT_OPAQUE_VALUE(ParaglobVal)
broker::expected<broker::data> ParaglobVal::DoSerialize() const broker::expected<broker::data> ParaglobVal::DoSerialize() const {
{
broker::vector d; broker::vector d;
std::unique_ptr<std::vector<uint8_t>> iv = this->internal_paraglob->serialize(); std::unique_ptr<std::vector<uint8_t>> iv = this->internal_paraglob->serialize();
for ( uint8_t a : *(iv.get()) ) for ( uint8_t a : *(iv.get()) )
@ -1048,8 +920,7 @@ broker::expected<broker::data> ParaglobVal::DoSerialize() const
return {std::move(d)}; return {std::move(d)};
} }
bool ParaglobVal::DoUnserialize(const broker::data& data) bool ParaglobVal::DoUnserialize(const broker::data& data) {
{
auto d = broker::get_if<broker::vector>(&data); auto d = broker::get_if<broker::vector>(&data);
if ( ! d ) if ( ! d )
return false; return false;
@ -1057,23 +928,17 @@ bool ParaglobVal::DoUnserialize(const broker::data& data)
std::unique_ptr<std::vector<uint8_t>> iv(new std::vector<uint8_t>); std::unique_ptr<std::vector<uint8_t>> iv(new std::vector<uint8_t>);
iv->resize(d->size()); iv->resize(d->size());
for ( std::vector<broker::data>::size_type i = 0; i < d->size(); ++i ) for ( std::vector<broker::data>::size_type i = 0; i < d->size(); ++i ) {
{
if ( ! get_vector_idx<uint64_t>(*d, i, iv.get()->data() + i) ) if ( ! get_vector_idx<uint64_t>(*d, i, iv.get()->data() + i) )
return false; return false;
} }
try try {
{
this->internal_paraglob = std::make_unique<paraglob::Paraglob>(std::move(iv)); this->internal_paraglob = std::make_unique<paraglob::Paraglob>(std::move(iv));
} } catch ( const paraglob::underflow_error& e ) {
catch ( const paraglob::underflow_error& e )
{
reporter->Error("Paraglob underflow error -> %s", e.what()); reporter->Error("Paraglob underflow error -> %s", e.what());
return false; return false;
} } catch ( const paraglob::overflow_error& e ) {
catch ( const paraglob::overflow_error& e )
{
reporter->Error("Paraglob overflow error -> %s", e.what()); reporter->Error("Paraglob overflow error -> %s", e.what());
return false; return false;
} }
@ -1081,46 +946,31 @@ bool ParaglobVal::DoUnserialize(const broker::data& data)
return true; return true;
} }
ValPtr ParaglobVal::DoClone(CloneState* state) ValPtr ParaglobVal::DoClone(CloneState* state) {
{ try {
try return make_intrusive<ParaglobVal>(std::make_unique<paraglob::Paraglob>(this->internal_paraglob->serialize()));
{ } catch ( const paraglob::underflow_error& e ) {
return make_intrusive<ParaglobVal>(
std::make_unique<paraglob::Paraglob>(this->internal_paraglob->serialize()));
}
catch ( const paraglob::underflow_error& e )
{
reporter->Error("Paraglob underflow error while cloning -> %s", e.what()); reporter->Error("Paraglob underflow error while cloning -> %s", e.what());
return nullptr; return nullptr;
} } catch ( const paraglob::overflow_error& e ) {
catch ( const paraglob::overflow_error& e )
{
reporter->Error("Paraglob overflow error while cloning -> %s", e.what()); reporter->Error("Paraglob overflow error while cloning -> %s", e.what());
return nullptr; return nullptr;
} }
} }
broker::expected<broker::data> TelemetryVal::DoSerialize() const broker::expected<broker::data> TelemetryVal::DoSerialize() const {
{
return broker::make_error(broker::ec::invalid_data, "cannot serialize metric handles"); return broker::make_error(broker::ec::invalid_data, "cannot serialize metric handles");
} }
bool TelemetryVal::DoUnserialize(const broker::data&) bool TelemetryVal::DoUnserialize(const broker::data&) { return false; }
{
return false;
}
TelemetryVal::TelemetryVal(telemetry::IntCounter) : OpaqueVal(int_counter_metric_type) {} TelemetryVal::TelemetryVal(telemetry::IntCounter) : OpaqueVal(int_counter_metric_type) {}
TelemetryVal::TelemetryVal(telemetry::IntCounterFamily) : OpaqueVal(int_counter_metric_family_type) TelemetryVal::TelemetryVal(telemetry::IntCounterFamily) : OpaqueVal(int_counter_metric_family_type) {}
{
}
TelemetryVal::TelemetryVal(telemetry::DblCounter) : OpaqueVal(dbl_counter_metric_type) {} TelemetryVal::TelemetryVal(telemetry::DblCounter) : OpaqueVal(dbl_counter_metric_type) {}
TelemetryVal::TelemetryVal(telemetry::DblCounterFamily) : OpaqueVal(dbl_counter_metric_family_type) TelemetryVal::TelemetryVal(telemetry::DblCounterFamily) : OpaqueVal(dbl_counter_metric_family_type) {}
{
}
TelemetryVal::TelemetryVal(telemetry::IntGauge) : OpaqueVal(int_gauge_metric_type) {} TelemetryVal::TelemetryVal(telemetry::IntGauge) : OpaqueVal(int_gauge_metric_type) {}
@ -1132,16 +982,10 @@ TelemetryVal::TelemetryVal(telemetry::DblGaugeFamily) : OpaqueVal(dbl_gauge_metr
TelemetryVal::TelemetryVal(telemetry::IntHistogram) : OpaqueVal(int_histogram_metric_type) {} TelemetryVal::TelemetryVal(telemetry::IntHistogram) : OpaqueVal(int_histogram_metric_type) {}
TelemetryVal::TelemetryVal(telemetry::IntHistogramFamily) TelemetryVal::TelemetryVal(telemetry::IntHistogramFamily) : OpaqueVal(int_histogram_metric_family_type) {}
: OpaqueVal(int_histogram_metric_family_type)
{
}
TelemetryVal::TelemetryVal(telemetry::DblHistogram) : OpaqueVal(dbl_histogram_metric_type) {} TelemetryVal::TelemetryVal(telemetry::DblHistogram) : OpaqueVal(dbl_histogram_metric_type) {}
TelemetryVal::TelemetryVal(telemetry::DblHistogramFamily) TelemetryVal::TelemetryVal(telemetry::DblHistogramFamily) : OpaqueVal(dbl_histogram_metric_family_type) {}
: OpaqueVal(dbl_histogram_metric_family_type)
{
}
} } // namespace zeek

View file

@ -21,20 +21,16 @@
#include "zeek/telemetry/Gauge.h" #include "zeek/telemetry/Gauge.h"
#include "zeek/telemetry/Histogram.h" #include "zeek/telemetry/Histogram.h"
namespace broker namespace broker {
{
class data; class data;
} }
namespace zeek namespace zeek {
{
namespace probabilistic namespace probabilistic {
{
class BloomFilter; class BloomFilter;
} }
namespace probabilistic::detail namespace probabilistic::detail {
{
class CardinalityCounter; class CardinalityCounter;
} }
@ -48,8 +44,7 @@ using BloomFilterValPtr = IntrusivePtr<BloomFilterVal>;
* Singleton that registers all available all available types of opaque * Singleton that registers all available all available types of opaque
* values. This facilitates their serialization into Broker values. * values. This facilitates their serialization into Broker values.
*/ */
class OpaqueMgr class OpaqueMgr {
{
public: public:
using Factory = OpaqueValPtr(); using Factory = OpaqueValPtr();
@ -84,8 +79,8 @@ public:
* Internal helper class to register an OpaqueVal-derived classes * Internal helper class to register an OpaqueVal-derived classes
* with the manager. * with the manager.
*/ */
template <class T> class Register template<class T>
{ class Register {
public: public:
Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); } Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); }
}; };
@ -114,8 +109,7 @@ private:
* completely internally, with no further script-level operators provided * completely internally, with no further script-level operators provided
* (other than bif functions). See OpaqueVal.h for derived classes. * (other than bif functions). See OpaqueVal.h for derived classes.
*/ */
class OpaqueVal : public Val class OpaqueVal : public Val {
{
public: public:
explicit OpaqueVal(OpaqueTypePtr t); explicit OpaqueVal(OpaqueTypePtr t);
~OpaqueVal() override = default; ~OpaqueVal() override = default;
@ -187,12 +181,10 @@ protected:
void ValDescribeReST(ODesc* d) const override; void ValDescribeReST(ODesc* d) const override;
}; };
class HashVal : public OpaqueVal class HashVal : public OpaqueVal {
{
public: public:
template<class T> template<class T>
static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) {
{
auto h = detail::hash_init(alg); auto h = detail::hash_init(alg);
for ( const auto& v : vlist ) for ( const auto& v : vlist )
@ -221,18 +213,15 @@ private:
bool valid; bool valid;
}; };
class MD5Val : public HashVal class MD5Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) {
digest_all(detail::Hash_MD5, vlist, result); digest_all(detail::Hash_MD5, vlist, result);
} }
template<class T> template<class T>
static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], u_char result[MD5_DIGEST_LENGTH]) {
u_char result[MD5_DIGEST_LENGTH])
{
digest(vlist, result); digest(vlist, result);
for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i ) for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i )
@ -262,11 +251,10 @@ private:
#endif #endif
}; };
class SHA1Val : public HashVal class SHA1Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) {
digest_all(detail::Hash_SHA1, vlist, result); digest_all(detail::Hash_SHA1, vlist, result);
} }
@ -291,11 +279,10 @@ private:
#endif #endif
}; };
class SHA256Val : public HashVal class SHA256Val : public HashVal {
{
public: public:
template <class T> static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) template<class T>
{ static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) {
digest_all(detail::Hash_SHA256, vlist, result); digest_all(detail::Hash_SHA256, vlist, result);
} }
@ -320,8 +307,7 @@ private:
#endif #endif
}; };
class EntropyVal : public OpaqueVal class EntropyVal : public OpaqueVal {
{
public: public:
EntropyVal(); EntropyVal();
@ -336,8 +322,7 @@ private:
detail::RandTest state; detail::RandTest state;
}; };
class BloomFilterVal : public OpaqueVal class BloomFilterVal : public OpaqueVal {
{
public: public:
explicit BloomFilterVal(probabilistic::BloomFilter* bf); explicit BloomFilterVal(probabilistic::BloomFilter* bf);
~BloomFilterVal() override; ~BloomFilterVal() override;
@ -373,8 +358,7 @@ private:
probabilistic::BloomFilter* bloom_filter; probabilistic::BloomFilter* bloom_filter;
}; };
class CardinalityVal : public OpaqueVal class CardinalityVal : public OpaqueVal {
{
public: public:
explicit CardinalityVal(probabilistic::detail::CardinalityCounter*); explicit CardinalityVal(probabilistic::detail::CardinalityCounter*);
~CardinalityVal() override; ~CardinalityVal() override;
@ -399,8 +383,7 @@ private:
probabilistic::detail::CardinalityCounter* c; probabilistic::detail::CardinalityCounter* c;
}; };
class ParaglobVal : public OpaqueVal class ParaglobVal : public OpaqueVal {
{
public: public:
explicit ParaglobVal(std::unique_ptr<paraglob::Paraglob> p); explicit ParaglobVal(std::unique_ptr<paraglob::Paraglob> p);
VectorValPtr Get(StringVal*& pattern); VectorValPtr Get(StringVal*& pattern);
@ -419,8 +402,7 @@ private:
/** /**
* Base class for metric handles. Handle types are not serializable. * Base class for metric handles. Handle types are not serializable.
*/ */
class TelemetryVal : public OpaqueVal class TelemetryVal : public OpaqueVal {
{
protected: protected:
explicit TelemetryVal(telemetry::IntCounter); explicit TelemetryVal(telemetry::IntCounter);
explicit TelemetryVal(telemetry::IntCounterFamily); explicit TelemetryVal(telemetry::IntCounterFamily);
@ -439,8 +421,8 @@ protected:
bool DoUnserialize(const broker::data& data) override; bool DoUnserialize(const broker::data& data) override;
}; };
template <class Handle> class TelemetryValImpl : public TelemetryVal template<class Handle>
{ class TelemetryValImpl : public TelemetryVal {
public: public:
using HandleType = Handle; using HandleType = Handle;

View file

@ -19,18 +19,15 @@
#include "zeek/logging/writers/ascii/Ascii.h" #include "zeek/logging/writers/ascii/Ascii.h"
#include "zeek/script_opt/ScriptOpt.h" #include "zeek/script_opt/ScriptOpt.h"
namespace zeek namespace zeek {
{
void Options::filter_supervisor_options() void Options::filter_supervisor_options() {
{
pcap_filter = {}; pcap_filter = {};
signature_files = {}; signature_files = {};
pcap_output_file = {}; pcap_output_file = {};
} }
void Options::filter_supervised_node_options() void Options::filter_supervised_node_options() {
{
auto og = *this; auto og = *this;
*this = {}; *this = {};
@ -71,117 +68,108 @@ void Options::filter_supervised_node_options()
script_options_to_set = og.script_options_to_set; script_options_to_set = og.script_options_to_set;
} }
bool fake_dns() bool fake_dns() { return getenv("ZEEK_DNS_FAKE"); }
{
return getenv("ZEEK_DNS_FAKE");
}
extern const char* zeek_version(); extern const char* zeek_version();
void usage(const char* prog, int code) void usage(const char* prog, int code) {
{
fprintf(stderr, "zeek version %s\n", zeek_version()); fprintf(stderr, "zeek version %s\n", zeek_version());
fprintf(stderr, "usage: %s [options] [file ...]\n", prog); fprintf(stderr, "usage: %s [options] [file ...]\n", prog);
fprintf(stderr, "usage: %s --test [doctest-options] -- [options] [file ...]\n", prog); fprintf(stderr, "usage: %s --test [doctest-options] -- [options] [file ...]\n", prog);
fprintf(stderr, " <file> | Zeek script file, or read stdin\n"); fprintf(stderr, " <file> | Zeek script file, or read stdin\n");
fprintf(stderr, fprintf(stderr, " -a|--parse-only | exit immediately after parsing scripts\n");
" -a|--parse-only | exit immediately after parsing scripts\n"); fprintf(stderr, " -b|--bare-mode | don't load scripts from the base/ directory\n");
fprintf(stderr, fprintf(stderr, " -c|--capture-unprocessed <file> | write unprocessed packets to a tcpdump file\n");
" -b|--bare-mode | don't load scripts from the base/ directory\n");
fprintf(stderr,
" -c|--capture-unprocessed <file> | write unprocessed packets to a tcpdump file\n");
fprintf(stderr, " -d|--debug-script | activate Zeek script debugging\n"); fprintf(stderr, " -d|--debug-script | activate Zeek script debugging\n");
fprintf(stderr, " -e|--exec <zeek code> | augment loaded scripts by given code\n"); fprintf(stderr, " -e|--exec <zeek code> | augment loaded scripts by given code\n");
fprintf(stderr, " -f|--filter <filter> | tcpdump filter\n"); fprintf(stderr, " -f|--filter <filter> | tcpdump filter\n");
fprintf(stderr, " -h|--help | command line help\n"); fprintf(stderr, " -h|--help | command line help\n");
fprintf(stderr, " -i|--iface <interface> | read from given interface (only one allowed)\n");
fprintf(stderr, " -p|--prefix <prefix> | add given prefix to Zeek script file resolution\n");
fprintf(stderr, fprintf(stderr,
" -i|--iface <interface> | read from given interface (only one allowed)\n"); " -r|--readfile <readfile> | read from given tcpdump file (only one "
fprintf(
stderr,
" -p|--prefix <prefix> | add given prefix to Zeek script file resolution\n");
fprintf(stderr, " -r|--readfile <readfile> | read from given tcpdump file (only one "
"allowed, pass '-' as the filename to read from stdin)\n"); "allowed, pass '-' as the filename to read from stdin)\n");
fprintf(stderr, " -s|--rulefile <rulefile> | read rules from given file\n"); fprintf(stderr, " -s|--rulefile <rulefile> | read rules from given file\n");
fprintf(stderr, " -t|--tracefile <tracefile> | activate execution tracing\n"); fprintf(stderr, " -t|--tracefile <tracefile> | activate execution tracing\n");
fprintf(stderr, " -u|--usage-issues | find variable usage issues and exit\n"); fprintf(stderr, " -u|--usage-issues | find variable usage issues and exit\n");
fprintf(stderr, " --no-unused-warnings | suppress warnings of unused " fprintf(stderr,
" --no-unused-warnings | suppress warnings of unused "
"functions/hooks/events\n"); "functions/hooks/events\n");
fprintf(stderr, " -v|--version | print version and exit\n"); fprintf(stderr, " -v|--version | print version and exit\n");
fprintf(stderr, " -V|--build-info | print build information and exit\n"); fprintf(stderr, " -V|--build-info | print build information and exit\n");
fprintf(stderr, " -w|--writefile <writefile> | write to given tcpdump file\n"); fprintf(stderr, " -w|--writefile <writefile> | write to given tcpdump file\n");
#ifdef DEBUG #ifdef DEBUG
fprintf(stderr, " -B|--debug <dbgstreams> | Enable debugging output for selected " fprintf(stderr,
" -B|--debug <dbgstreams> | Enable debugging output for selected "
"streams ('-B help' for help)\n"); "streams ('-B help' for help)\n");
#endif #endif
fprintf(stderr, " -C|--no-checksums | ignore checksums\n"); fprintf(stderr, " -C|--no-checksums | ignore checksums\n");
fprintf(stderr, " -D|--deterministic | initialize random seeds to zero\n"); fprintf(stderr, " -D|--deterministic | initialize random seeds to zero\n");
fprintf(stderr, " -E|--event-trace <file> | generate a replayable event trace to " fprintf(stderr,
" -E|--event-trace <file> | generate a replayable event trace to "
"the given file\n"); "the given file\n");
fprintf(stderr, " -F|--force-dns | force DNS\n"); fprintf(stderr, " -F|--force-dns | force DNS\n");
fprintf(stderr, " -G|--load-seeds <file> | load seeds from given file\n"); fprintf(stderr, " -G|--load-seeds <file> | load seeds from given file\n");
fprintf(stderr, " -H|--save-seeds <file> | save seeds to given file\n"); fprintf(stderr, " -H|--save-seeds <file> | save seeds to given file\n");
fprintf(stderr, " -I|--print-id <ID name> | print out given ID\n"); fprintf(stderr, " -I|--print-id <ID name> | print out given ID\n");
fprintf(stderr, " -N|--print-plugins | print available plugins and exit (-NN " fprintf(stderr,
" -N|--print-plugins | print available plugins and exit (-NN "
"for verbose)\n"); "for verbose)\n");
fprintf(stderr, " -O|--optimize <option> | enable script optimization (use -O help " fprintf(stderr,
" -O|--optimize <option> | enable script optimization (use -O help "
"for options)\n"); "for options)\n");
fprintf(stderr, " -0|--optimize-files=<pat> | enable script optimization for all " fprintf(stderr,
" -0|--optimize-files=<pat> | enable script optimization for all "
"functions in files with names containing the given pattern\n"); "functions in files with names containing the given pattern\n");
fprintf(stderr, " -o|--optimize-funcs=<pat> | enable script optimization for " fprintf(stderr,
" -o|--optimize-funcs=<pat> | enable script optimization for "
"functions with names fully matching the given pattern\n"); "functions with names fully matching the given pattern\n");
fprintf(stderr, " -P|--prime-dns | prime DNS\n"); fprintf(stderr, " -P|--prime-dns | prime DNS\n");
fprintf(stderr, fprintf(stderr, " -Q|--time | print execution time summary to stderr\n");
" -Q|--time | print execution time summary to stderr\n");
fprintf(stderr, " -S|--debug-rules | enable rule debugging\n"); fprintf(stderr, " -S|--debug-rules | enable rule debugging\n");
fprintf(stderr, " -T|--re-level <level> | set 'RE_level' for rules\n"); fprintf(stderr, " -T|--re-level <level> | set 'RE_level' for rules\n");
fprintf(stderr, " -U|--status-file <file> | Record process status in file\n"); fprintf(stderr, " -U|--status-file <file> | Record process status in file\n");
fprintf(stderr, " -W|--watchdog | activate watchdog timer\n"); fprintf(stderr, " -W|--watchdog | activate watchdog timer\n");
fprintf(stderr, fprintf(stderr, " -X|--zeekygen <cfgfile> | generate documentation based on config file\n");
" -X|--zeekygen <cfgfile> | generate documentation based on config file\n");
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
fprintf(stderr, " -m|--mem-leaks | show leaks [perftools]\n"); fprintf(stderr, " -m|--mem-leaks | show leaks [perftools]\n");
fprintf(stderr, " -M|--mem-profile | record heap [perftools]\n"); fprintf(stderr, " -M|--mem-profile | record heap [perftools]\n");
#endif #endif
fprintf( fprintf(stderr, " --profile-scripts[=file] | profile scripts to given file (default stdout)\n");
stderr,
" --profile-scripts[=file] | profile scripts to given file (default stdout)\n");
fprintf(stderr, fprintf(stderr,
" --profile-script-call-stacks | add call stacks to profile output (requires " " --profile-script-call-stacks | add call stacks to profile output (requires "
"--profile-scripts)\n"); "--profile-scripts)\n");
fprintf(stderr, " --pseudo-realtime[=<speedup>] | enable pseudo-realtime for performance " fprintf(stderr,
" --pseudo-realtime[=<speedup>] | enable pseudo-realtime for performance "
"evaluation (default 1)\n"); "evaluation (default 1)\n");
fprintf(stderr, " -j|--jobs | enable supervisor mode\n"); fprintf(stderr, " -j|--jobs | enable supervisor mode\n");
fprintf(stderr, " --test | run unit tests ('--test -h' for help, " fprintf(stderr,
" --test | run unit tests ('--test -h' for help, "
"not available when built without ENABLE_ZEEK_UNIT_TESTS)\n"); "not available when built without ENABLE_ZEEK_UNIT_TESTS)\n");
fprintf(stderr, " $ZEEKPATH | file search path (%s)\n", fprintf(stderr, " $ZEEKPATH | file search path (%s)\n", util::zeek_path().c_str());
util::zeek_path().c_str()); fprintf(stderr, " $ZEEK_PLUGIN_PATH | plugin search path (%s)\n", util::zeek_plugin_path());
fprintf(stderr, " $ZEEK_PLUGIN_PATH | plugin search path (%s)\n",
util::zeek_plugin_path());
fprintf(stderr, " $ZEEK_PLUGIN_ACTIVATE | plugins to always activate (%s)\n", fprintf(stderr, " $ZEEK_PLUGIN_ACTIVATE | plugins to always activate (%s)\n",
util::zeek_plugin_activate()); util::zeek_plugin_activate());
fprintf(stderr, " $ZEEK_PREFIXES | prefix list (%s)\n", fprintf(stderr, " $ZEEK_PREFIXES | prefix list (%s)\n", util::zeek_prefixes().c_str());
util::zeek_prefixes().c_str()); fprintf(stderr, " $ZEEK_DNS_FAKE | disable DNS lookups (%s)\n", fake_dns() ? "on" : "off");
fprintf(stderr, " $ZEEK_DNS_FAKE | disable DNS lookups (%s)\n",
fake_dns() ? "on" : "off");
fprintf(stderr, " $ZEEK_SEED_VALUES | list of space separated seeds (%s)\n", fprintf(stderr, " $ZEEK_SEED_VALUES | list of space separated seeds (%s)\n",
getenv("ZEEK_SEED_VALUES") ? "set" : "not set"); getenv("ZEEK_SEED_VALUES") ? "set" : "not set");
fprintf(stderr, " $ZEEK_SEED_FILE | file to load seeds from (not set)\n"); fprintf(stderr, " $ZEEK_SEED_FILE | file to load seeds from (not set)\n");
fprintf(stderr, " $ZEEK_LOG_SUFFIX | ASCII log file extension (.%s)\n", fprintf(stderr, " $ZEEK_LOG_SUFFIX | ASCII log file extension (.%s)\n",
logging::writer::detail::Ascii::LogExt().c_str()); logging::writer::detail::Ascii::LogExt().c_str());
fprintf(stderr, " $ZEEK_PROFILER_FILE | Output file for script execution " fprintf(stderr,
" $ZEEK_PROFILER_FILE | Output file for script execution "
"statistics (not set)\n"); "statistics (not set)\n");
fprintf(stderr, fprintf(stderr, " $ZEEK_DISABLE_ZEEKYGEN | Disable Zeekygen documentation support (%s)\n",
" $ZEEK_DISABLE_ZEEKYGEN | Disable Zeekygen documentation support (%s)\n",
getenv("ZEEK_DISABLE_ZEEKYGEN") ? "set" : "not set"); getenv("ZEEK_DISABLE_ZEEKYGEN") ? "set" : "not set");
fprintf(stderr, " $ZEEK_DNS_RESOLVER | IPv4/IPv6 address of DNS resolver to use (%s)\n",
getenv("ZEEK_DNS_RESOLVER") ? getenv("ZEEK_DNS_RESOLVER") :
"not set, will use first IPv4 address from /etc/resolv.conf");
fprintf(stderr, fprintf(stderr,
" $ZEEK_DNS_RESOLVER | IPv4/IPv6 address of DNS resolver to use (%s)\n", " $ZEEK_DEBUG_LOG_STDERR | Use stderr for debug logs generated via "
getenv("ZEEK_DNS_RESOLVER")
? getenv("ZEEK_DNS_RESOLVER")
: "not set, will use first IPv4 address from /etc/resolv.conf");
fprintf(stderr, " $ZEEK_DEBUG_LOG_STDERR | Use stderr for debug logs generated via "
"the -B flag"); "the -B flag");
fprintf(stderr, "\n"); fprintf(stderr, "\n");
@ -189,8 +177,7 @@ void usage(const char* prog, int code)
exit(code); exit(code);
} }
static void print_analysis_help() static void print_analysis_help() {
{
fprintf(stderr, "--optimize options when using ZAM:\n"); fprintf(stderr, "--optimize options when using ZAM:\n");
fprintf(stderr, " ZAM execute scripts using ZAM and all optimizations\n"); fprintf(stderr, " ZAM execute scripts using ZAM and all optimizations\n");
fprintf(stderr, " help print this list\n"); fprintf(stderr, " help print this list\n");
@ -199,21 +186,17 @@ static void print_analysis_help()
fprintf(stderr, " dump-uds dump use-defs to stdout; implies xform\n"); fprintf(stderr, " dump-uds dump use-defs to stdout; implies xform\n");
fprintf(stderr, " dump-xform dump transformed scripts to stdout; implies xform\n"); fprintf(stderr, " dump-xform dump transformed scripts to stdout; implies xform\n");
fprintf(stderr, " dump-ZAM dump generated ZAM code; implies gen-ZAM-code\n"); fprintf(stderr, " dump-ZAM dump generated ZAM code; implies gen-ZAM-code\n");
fprintf(stderr, fprintf(stderr, " gen-ZAM-code generate ZAM code (without turning on additional optimizations)\n");
" gen-ZAM-code generate ZAM code (without turning on additional optimizations)\n");
fprintf(stderr, " inline inline function calls\n"); fprintf(stderr, " inline inline function calls\n");
fprintf(stderr, " no-ZAM-opt omit low-level ZAM optimization\n"); fprintf(stderr, " no-ZAM-opt omit low-level ZAM optimization\n");
fprintf(stderr, " optimize-all optimize all scripts, even inlined ones\n"); fprintf(stderr, " optimize-all optimize all scripts, even inlined ones\n");
fprintf(stderr, " optimize-AST optimize the (transformed) AST; implies xform\n"); fprintf(stderr, " optimize-AST optimize the (transformed) AST; implies xform\n");
fprintf(stderr, fprintf(stderr, " profile-ZAM generate to stdout a ZAM execution profile; implies -O ZAM\n");
" profile-ZAM generate to stdout a ZAM execution profile; implies -O ZAM\n");
fprintf(stderr, " report-recursive report on recursive functions and exit\n"); fprintf(stderr, " report-recursive report on recursive functions and exit\n");
fprintf(stderr, " xform transform scripts to \"reduced\" form\n"); fprintf(stderr, " xform transform scripts to \"reduced\" form\n");
fprintf(stderr, "\n--optimize options when generating C++:\n"); fprintf(stderr, "\n--optimize options when generating C++:\n");
fprintf( fprintf(stderr, " allow-cond allow standalone compilation of functions influenced by conditionals\n");
stderr,
" allow-cond allow standalone compilation of functions influenced by conditionals\n");
fprintf(stderr, " gen-C++ generate C++ script bodies\n"); fprintf(stderr, " gen-C++ generate C++ script bodies\n");
fprintf(stderr, " gen-standalone-C++ generate \"standalone\" C++ script bodies\n"); fprintf(stderr, " gen-standalone-C++ generate \"standalone\" C++ script bodies\n");
fprintf(stderr, " help print this list\n"); fprintf(stderr, " help print this list\n");
@ -222,19 +205,16 @@ static void print_analysis_help()
fprintf(stderr, " use-C++ use available C++ script bodies\n"); fprintf(stderr, " use-C++ use available C++ script bodies\n");
} }
static void set_analysis_option(const char* opt, Options& opts) static void set_analysis_option(const char* opt, Options& opts) {
{
auto& a_o = opts.analysis_options; auto& a_o = opts.analysis_options;
if ( ! opt || util::streq(opt, "ZAM") ) if ( ! opt || util::streq(opt, "ZAM") ) {
{
a_o.inliner = a_o.optimize_AST = a_o.activate = true; a_o.inliner = a_o.optimize_AST = a_o.activate = true;
a_o.gen_ZAM = true; a_o.gen_ZAM = true;
return; return;
} }
if ( util::streq(opt, "help") ) if ( util::streq(opt, "help") ) {
{
print_analysis_help(); print_analysis_help();
exit(0); exit(0);
} }
@ -274,16 +254,14 @@ static void set_analysis_option(const char* opt, Options& opts)
else if ( util::streq(opt, "xform") ) else if ( util::streq(opt, "xform") )
a_o.activate = true; a_o.activate = true;
else else {
{
fprintf(stderr, "zeek: unrecognized -O/--optimize option: %s\n\n", opt); fprintf(stderr, "zeek: unrecognized -O/--optimize option: %s\n\n", opt);
print_analysis_help(); print_analysis_help();
exit(1); exit(1);
} }
} }
Options parse_cmdline(int argc, char** argv) Options parse_cmdline(int argc, char** argv) {
{
Options rval; Options rval;
// When running unit tests, the first argument on the command line must be // When running unit tests, the first argument on the command line must be
@ -295,26 +273,22 @@ Options parse_cmdline(int argc, char** argv)
// Just locally filtering out the args for Zeek usage from doctest args. // Just locally filtering out the args for Zeek usage from doctest args.
std::vector<std::string> zeek_args; std::vector<std::string> zeek_args;
if ( argc > 1 && strcmp(argv[1], "--test") == 0 ) if ( argc > 1 && strcmp(argv[1], "--test") == 0 ) {
{
#ifdef DOCTEST_CONFIG_DISABLE #ifdef DOCTEST_CONFIG_DISABLE
fprintf(stderr, "ERROR: C++ unit tests are disabled for this build.\n" fprintf(stderr,
"ERROR: C++ unit tests are disabled for this build.\n"
" Please re-compile with ENABLE_ZEEK_UNIT_TESTS " " Please re-compile with ENABLE_ZEEK_UNIT_TESTS "
"to run the C++ unit tests.\n"); "to run the C++ unit tests.\n");
exit(1); exit(1);
#endif #endif
auto is_separator = [](const char* cstr) auto is_separator = [](const char* cstr) { return strcmp(cstr, "--") == 0; };
{
return strcmp(cstr, "--") == 0;
};
auto first = argv; auto first = argv;
auto last = argv + argc; auto last = argv + argc;
auto separator = std::find_if(first, last, is_separator); auto separator = std::find_if(first, last, is_separator);
zeek_args.emplace_back(argv[0]); zeek_args.emplace_back(argv[0]);
if ( separator != last ) if ( separator != last ) {
{
auto first_zeek_arg = std::next(separator); auto first_zeek_arg = std::next(separator);
for ( auto i = first_zeek_arg; i != last; ++i ) for ( auto i = first_zeek_arg; i != last; ++i )
@ -326,34 +300,26 @@ Options parse_cmdline(int argc, char** argv)
for ( ptrdiff_t i = 0; i < std::distance(first, separator); ++i ) for ( ptrdiff_t i = 0; i < std::distance(first, separator); ++i )
rval.doctest_args.emplace_back(argv[i]); rval.doctest_args.emplace_back(argv[i]);
} }
else else {
{ if ( argc > 1 ) {
if ( argc > 1 ) auto endsWith = [](const std::string& str, const std::string& suffix) {
{
auto endsWith = [](const std::string& str, const std::string& suffix)
{
return str.size() >= suffix.size() && return str.size() >= suffix.size() &&
0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix); 0 == str.compare(str.size() - suffix.size(), suffix.size(), suffix);
}; };
auto i = 0; auto i = 0;
for ( ; i < argc && ! endsWith(argv[i], "--"); ++i ) for ( ; i < argc && ! endsWith(argv[i], "--"); ++i ) {
{
zeek_args.emplace_back(argv[i]); zeek_args.emplace_back(argv[i]);
} }
if ( i < argc ) if ( i < argc ) {
{
// If a script is invoked with Zeek as the interpreter, the arguments provided // If a script is invoked with Zeek as the interpreter, the arguments provided
// directly in the interpreter line of the script won't be broken apart in the // directly in the interpreter line of the script won't be broken apart in the
// argv on Linux so we split it up here. // argv on Linux so we split it up here.
if ( endsWith(argv[i], "--") && zeek_args.size() == 1 ) if ( endsWith(argv[i], "--") && zeek_args.size() == 1 ) {
{
std::istringstream iss(argv[i]); std::istringstream iss(argv[i]);
for ( std::string s; iss >> s; ) for ( std::string s; iss >> s; ) {
{ if ( ! endsWith(s, "--") ) {
if ( ! endsWith(s, "--") )
{
zeek_args.emplace_back(s); zeek_args.emplace_back(s);
} }
} }
@ -434,8 +400,7 @@ Options parse_cmdline(int argc, char** argv)
}; };
char opts[256]; char opts[256];
util::safe_strncpy(opts, "B:c:E:e:f:G:H:I:i:j::n:O:0:o:p:r:s:T:t:U:w:X:CDFMNPQSWabdhmuvV", util::safe_strncpy(opts, "B:c:E:e:f:G:H:I:i:j::n:O:0:o:p:r:s:T:t:U:w:X:CDFMNPQSWabdhmuvV", sizeof(opts));
sizeof(opts));
int op; int op;
int long_optsind; int long_optsind;
@ -447,40 +412,22 @@ Options parse_cmdline(int argc, char** argv)
for ( size_t i = 0; i < zeek_args.size(); ++i ) for ( size_t i = 0; i < zeek_args.size(); ++i )
zargs[i] = zeek_args[i].data(); zargs[i] = zeek_args[i].data();
while ( (op = getopt_long(zeek_args.size(), zargs.get(), opts, long_opts, &long_optsind)) != while ( (op = getopt_long(zeek_args.size(), zargs.get(), opts, long_opts, &long_optsind)) != EOF )
EOF ) switch ( op ) {
switch ( op ) case 'a': rval.parse_only = true; break;
{ case 'b': rval.bare_mode = true; break;
case 'a': case 'c': rval.unprocessed_output_file = optarg; break;
rval.parse_only = true; case 'd': rval.debug_scripts = true; break;
break; case 'e': rval.script_code_to_exec = optarg; break;
case 'b': case 'f': rval.pcap_filter = optarg; break;
rval.bare_mode = true; case 'h': rval.print_usage = true; break;
break;
case 'c':
rval.unprocessed_output_file = optarg;
break;
case 'd':
rval.debug_scripts = true;
break;
case 'e':
rval.script_code_to_exec = optarg;
break;
case 'f':
rval.pcap_filter = optarg;
break;
case 'h':
rval.print_usage = true;
break;
case 'i': case 'i':
if ( rval.interface ) if ( rval.interface ) {
{
fprintf(stderr, "ERROR: Only a single interface option (-i) is allowed.\n"); fprintf(stderr, "ERROR: Only a single interface option (-i) is allowed.\n");
exit(1); exit(1);
} }
if ( rval.pcap_file ) if ( rval.pcap_file ) {
{
fprintf(stderr, "ERROR: Using -i is not allow when reading a pcap file.\n"); fprintf(stderr, "ERROR: Using -i is not allow when reading a pcap file.\n");
exit(1); exit(1);
} }
@ -489,128 +436,74 @@ Options parse_cmdline(int argc, char** argv)
break; break;
case 'j': case 'j':
rval.supervisor_mode = true; rval.supervisor_mode = true;
if ( optarg ) if ( optarg ) {
{
// TODO: for supervised offline pcap reading, the argument is // TODO: for supervised offline pcap reading, the argument is
// expected to be number of workers like "-j 4" or possibly a // expected to be number of workers like "-j 4" or possibly a
// list of worker/proxy/logger counts like "-j 4,2,1" // list of worker/proxy/logger counts like "-j 4,2,1"
} }
break; break;
case 'p': case 'p': rval.script_prefixes.emplace_back(optarg); break;
rval.script_prefixes.emplace_back(optarg);
break;
case 'r': case 'r':
if ( rval.pcap_file ) if ( rval.pcap_file ) {
{
fprintf(stderr, "ERROR: Only a single readfile option (-r) is allowed.\n"); fprintf(stderr, "ERROR: Only a single readfile option (-r) is allowed.\n");
exit(1); exit(1);
} }
if ( rval.interface ) if ( rval.interface ) {
{
fprintf(stderr, "Using -r is not allowed when reading a live interface.\n"); fprintf(stderr, "Using -r is not allowed when reading a live interface.\n");
exit(1); exit(1);
} }
rval.pcap_file = optarg; rval.pcap_file = optarg;
break; break;
case 's': case 's': rval.signature_files.emplace_back(optarg); break;
rval.signature_files.emplace_back(optarg); case 't': rval.debug_script_tracing_file = optarg; break;
break; case 'u': ++rval.analysis_options.usage_issues; break;
case 't': case 'v': rval.print_version = true; break;
rval.debug_script_tracing_file = optarg; case 'V': rval.print_build_info = true; break;
break; case 'w': rval.pcap_output_file = optarg; break;
case 'u':
++rval.analysis_options.usage_issues;
break;
case 'v':
rval.print_version = true;
break;
case 'V':
rval.print_build_info = true;
break;
case 'w':
rval.pcap_output_file = optarg;
break;
case 'B': case 'B':
#ifdef DEBUG #ifdef DEBUG
rval.debug_log_streams = optarg; rval.debug_log_streams = optarg;
#else #else
if ( util::streq(optarg, "help") ) if ( util::streq(optarg, "help") ) {
{
fprintf(stderr, "debug streams unavailable\n"); fprintf(stderr, "debug streams unavailable\n");
exit(1); exit(1);
} }
#endif #endif
break; break;
case 'C': case 'C': rval.ignore_checksums = true; break;
rval.ignore_checksums = true; case 'D': rval.deterministic_mode = true; break;
break; case 'E': rval.event_trace_file = optarg; break;
case 'D':
rval.deterministic_mode = true;
break;
case 'E':
rval.event_trace_file = optarg;
break;
case 'F': case 'F':
if ( rval.dns_mode != detail::DNS_DEFAULT ) if ( rval.dns_mode != detail::DNS_DEFAULT )
usage(zargs[0], 1); usage(zargs[0], 1);
rval.dns_mode = detail::DNS_FORCE; rval.dns_mode = detail::DNS_FORCE;
break; break;
case 'G': case 'G': rval.random_seed_input_file = optarg; break;
rval.random_seed_input_file = optarg; case 'H': rval.random_seed_output_file = optarg; break;
break; case 'I': rval.identifier_to_print = optarg; break;
case 'H': case 'N': ++rval.print_plugins; break;
rval.random_seed_output_file = optarg; case 'O': set_analysis_option(optarg, rval); break;
break; case 'o': add_func_analysis_pattern(rval.analysis_options, optarg); break;
case 'I': case '0': add_file_analysis_pattern(rval.analysis_options, optarg); break;
rval.identifier_to_print = optarg;
break;
case 'N':
++rval.print_plugins;
break;
case 'O':
set_analysis_option(optarg, rval);
break;
case 'o':
add_func_analysis_pattern(rval.analysis_options, optarg);
break;
case '0':
add_file_analysis_pattern(rval.analysis_options, optarg);
break;
case 'P': case 'P':
if ( rval.dns_mode != detail::DNS_DEFAULT ) if ( rval.dns_mode != detail::DNS_DEFAULT )
usage(zargs[0], 1); usage(zargs[0], 1);
rval.dns_mode = detail::DNS_PRIME; rval.dns_mode = detail::DNS_PRIME;
break; break;
case 'Q': case 'Q': rval.print_execution_time = true; break;
rval.print_execution_time = true; case 'S': rval.print_signature_debug_info = true; break;
break; case 'T': rval.signature_re_level = atoi(optarg); break;
case 'S': case 'U': rval.process_status_file = optarg; break;
rval.print_signature_debug_info = true; case 'W': rval.use_watchdog = true; break;
break; case 'X': rval.zeekygen_config_file = optarg; break;
case 'T':
rval.signature_re_level = atoi(optarg);
break;
case 'U':
rval.process_status_file = optarg;
break;
case 'W':
rval.use_watchdog = true;
break;
case 'X':
rval.zeekygen_config_file = optarg;
break;
#ifdef USE_PERFTOOLS_DEBUG #ifdef USE_PERFTOOLS_DEBUG
case 'm': case 'm': rval.perftools_check_leaks = 1; break;
rval.perftools_check_leaks = 1; case 'M': rval.perftools_profile = 1; break;
break;
case 'M':
rval.perftools_profile = 1;
break;
#endif #endif
case '~': case '~':
@ -627,15 +520,13 @@ Options parse_cmdline(int argc, char** argv)
case 0: case 0:
// This happens for long options that don't have // This happens for long options that don't have
// a short-option equivalent. // a short-option equivalent.
if ( profile_scripts ) if ( profile_scripts ) {
{
profile_filename = optarg ? optarg : ""; profile_filename = optarg ? optarg : "";
enable_script_profile = true; enable_script_profile = true;
profile_scripts = 0; profile_scripts = 0;
} }
if ( profile_script_call_stacks ) if ( profile_script_call_stacks ) {
{
enable_script_profile_call_stacks = true; enable_script_profile_call_stacks = true;
profile_script_call_stacks = 0; profile_script_call_stacks = 0;
} }
@ -645,18 +536,13 @@ Options parse_cmdline(int argc, char** argv)
break; break;
case '?': case '?':
default: default: usage(zargs[0], 1); break;
usage(zargs[0], 1);
break;
} }
if ( ! enable_script_profile && enable_script_profile_call_stacks ) if ( ! enable_script_profile && enable_script_profile_call_stacks )
fprintf( fprintf(stderr, "ERROR: --profile-scripts-traces requires --profile-scripts to be passed as well.\n");
stderr,
"ERROR: --profile-scripts-traces requires --profile-scripts to be passed as well.\n");
if ( enable_script_profile ) if ( enable_script_profile ) {
{
activate_script_profiling(profile_filename.empty() ? nullptr : profile_filename.c_str(), activate_script_profiling(profile_filename.empty() ? nullptr : profile_filename.c_str(),
enable_script_profile_call_stacks); enable_script_profile_call_stacks);
} }
@ -664,8 +550,7 @@ Options parse_cmdline(int argc, char** argv)
// Process remaining arguments. X=Y arguments indicate script // Process remaining arguments. X=Y arguments indicate script
// variable/parameter assignments. X::Y arguments indicate plugins to // variable/parameter assignments. X::Y arguments indicate plugins to
// activate/query. The remainder are treated as scripts to load. // activate/query. The remainder are treated as scripts to load.
while ( optind < static_cast<int>(zeek_args.size()) ) while ( optind < static_cast<int>(zeek_args.size()) ) {
{
if ( strchr(zargs[optind], '=') ) if ( strchr(zargs[optind], '=') )
rval.script_options_to_set.emplace_back(zargs[optind++]); rval.script_options_to_set.emplace_back(zargs[optind++]);
else if ( strstr(zargs[optind], "::") ) else if ( strstr(zargs[optind], "::") )
@ -674,8 +559,7 @@ Options parse_cmdline(int argc, char** argv)
rval.scripts_to_load.emplace_back(zargs[optind++]); rval.scripts_to_load.emplace_back(zargs[optind++]);
} }
auto canonify_script_path = [](std::string* path) auto canonify_script_path = [](std::string* path) {
{
if ( path->empty() ) if ( path->empty() )
return; return;
@ -685,13 +569,11 @@ Options parse_cmdline(int argc, char** argv)
// Absolute path // Absolute path
return; return;
if ( (*path)[0] != '.' ) if ( (*path)[0] != '.' ) {
{
// Look up file in ZEEKPATH // Look up file in ZEEKPATH
auto res = util::find_script_file(*path, util::zeek_path()); auto res = util::find_script_file(*path, util::zeek_path());
if ( res.empty() ) if ( res.empty() ) {
{
fprintf(stderr, "failed to locate script: %s\n", path->data()); fprintf(stderr, "failed to locate script: %s\n", path->data());
exit(1); exit(1);
} }
@ -706,8 +588,7 @@ Options parse_cmdline(int argc, char** argv)
// Need to translate relative path to absolute. // Need to translate relative path to absolute.
char cwd[PATH_MAX]; char cwd[PATH_MAX];
if ( ! getcwd(cwd, sizeof(cwd)) ) if ( ! getcwd(cwd, sizeof(cwd)) ) {
{
fprintf(stderr, "failed to get current directory: %s\n", strerror(errno)); fprintf(stderr, "failed to get current directory: %s\n", strerror(errno));
exit(1); exit(1);
} }
@ -715,8 +596,7 @@ Options parse_cmdline(int argc, char** argv)
*path = std::string(cwd) + "/" + *path; *path = std::string(cwd) + "/" + *path;
}; };
if ( rval.supervisor_mode ) if ( rval.supervisor_mode ) {
{
// Translate any relative paths supplied to supervisor into absolute // Translate any relative paths supplied to supervisor into absolute
// paths for use by supervised nodes since they have the option to // paths for use by supervised nodes since they have the option to
// operate out of a different working directory. // operate out of a different working directory.

View file

@ -9,15 +9,13 @@
#include "zeek/DNS_Mgr.h" #include "zeek/DNS_Mgr.h"
#include "zeek/script_opt/ScriptOpt.h" #include "zeek/script_opt/ScriptOpt.h"
namespace zeek namespace zeek {
{
/** /**
* Options that define general Zeek processing behavior, usually determined * Options that define general Zeek processing behavior, usually determined
* from command-line arguments. * from command-line arguments.
*/ */
struct Options struct Options {
{
/** /**
* Unset options that aren't meant to be used by the supervisor, but may * Unset options that aren't meant to be used by the supervisor, but may
* make sense for supervised nodes to inherit (as opposed to flagging * make sense for supervised nodes to inherit (as opposed to flagging

View file

@ -4,11 +4,9 @@
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val) bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val) {
{
if ( ! to_type || ! from_type ) if ( ! to_type || ! from_type )
return true; return true;
@ -18,16 +16,14 @@ bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, cons
if ( to_type->InternalType() == TYPE_INTERNAL_DOUBLE ) if ( to_type->InternalType() == TYPE_INTERNAL_DOUBLE )
return false; return false;
if ( to_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) if ( to_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) {
{
if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE ) if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE )
return double_to_count_would_overflow(val->InternalDouble()); return double_to_count_would_overflow(val->InternalDouble());
if ( from_type->InternalType() == TYPE_INTERNAL_INT ) if ( from_type->InternalType() == TYPE_INTERNAL_INT )
return int_to_count_would_overflow(val->InternalInt()); return int_to_count_would_overflow(val->InternalInt());
} }
if ( to_type->InternalType() == TYPE_INTERNAL_INT ) if ( to_type->InternalType() == TYPE_INTERNAL_INT ) {
{
if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE ) if ( from_type->InternalType() == TYPE_INTERNAL_DOUBLE )
return double_to_int_would_overflow(val->InternalDouble()); return double_to_int_would_overflow(val->InternalDouble());
if ( from_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) if ( from_type->InternalType() == TYPE_INTERNAL_UNSIGNED )
@ -37,4 +33,4 @@ bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, cons
return false; return false;
} }
} } // namespace zeek::detail

View file

@ -4,29 +4,18 @@
#include "zeek/Type.h" #include "zeek/Type.h"
namespace zeek::detail namespace zeek::detail {
{
inline bool double_to_count_would_overflow(double v) inline bool double_to_count_would_overflow(double v) { return v < 0.0 || v > static_cast<double>(UINT64_MAX); }
{
return v < 0.0 || v > static_cast<double>(UINT64_MAX);
}
inline bool int_to_count_would_overflow(zeek_int_t v) inline bool int_to_count_would_overflow(zeek_int_t v) { return v < 0.0; }
{
return v < 0.0;
}
inline bool double_to_int_would_overflow(double v) inline bool double_to_int_would_overflow(double v) {
{
return v < static_cast<double>(INT64_MIN) || v > static_cast<double>(INT64_MAX); return v < static_cast<double>(INT64_MIN) || v > static_cast<double>(INT64_MAX);
} }
inline bool count_to_int_would_overflow(zeek_uint_t v) inline bool count_to_int_would_overflow(zeek_uint_t v) { return v > INT64_MAX; }
{
return v > INT64_MAX;
}
extern bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val); extern bool would_overflow(const zeek::Type* from_type, const zeek::Type* to_type, const Val* val);
} } // namespace zeek::detail

View file

@ -2,24 +2,20 @@
#include "zeek/IP.h" #include "zeek/IP.h"
namespace zeek::detail namespace zeek::detail {
{
void PacketFilter::DeleteFilter(void* data) void PacketFilter::DeleteFilter(void* data) {
{
auto f = static_cast<Filter*>(data); auto f = static_cast<Filter*>(data);
delete f; delete f;
} }
PacketFilter::PacketFilter(bool arg_default) PacketFilter::PacketFilter(bool arg_default) {
{
default_match = arg_default; default_match = arg_default;
src_filter.SetDeleteFunction(PacketFilter::DeleteFilter); src_filter.SetDeleteFunction(PacketFilter::DeleteFilter);
dst_filter.SetDeleteFunction(PacketFilter::DeleteFilter); dst_filter.SetDeleteFunction(PacketFilter::DeleteFilter);
} }
void PacketFilter::AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability) void PacketFilter::AddSrc(const IPAddr& src, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
@ -27,8 +23,7 @@ void PacketFilter::AddSrc(const IPAddr& src, uint32_t tcp_flags, double probabil
delete prev; delete prev;
} }
void PacketFilter::AddSrc(Val* src, uint32_t tcp_flags, double probability) void PacketFilter::AddSrc(Val* src, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
@ -36,8 +31,7 @@ void PacketFilter::AddSrc(Val* src, uint32_t tcp_flags, double probability)
delete prev; delete prev;
} }
void PacketFilter::AddDst(const IPAddr& dst, uint32_t tcp_flags, double probability) void PacketFilter::AddDst(const IPAddr& dst, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
@ -45,8 +39,7 @@ void PacketFilter::AddDst(const IPAddr& dst, uint32_t tcp_flags, double probabil
delete prev; delete prev;
} }
void PacketFilter::AddDst(Val* dst, uint32_t tcp_flags, double probability) void PacketFilter::AddDst(Val* dst, uint32_t tcp_flags, double probability) {
{
Filter* f = new Filter; Filter* f = new Filter;
f->tcp_flags = tcp_flags; f->tcp_flags = tcp_flags;
f->probability = probability * static_cast<double>(util::detail::max_random()); f->probability = probability * static_cast<double>(util::detail::max_random());
@ -54,36 +47,31 @@ void PacketFilter::AddDst(Val* dst, uint32_t tcp_flags, double probability)
delete prev; delete prev;
} }
bool PacketFilter::RemoveSrc(const IPAddr& src) bool PacketFilter::RemoveSrc(const IPAddr& src) {
{
auto f = static_cast<Filter*>(src_filter.Remove(src, 128)); auto f = static_cast<Filter*>(src_filter.Remove(src, 128));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::RemoveSrc(Val* src) bool PacketFilter::RemoveSrc(Val* src) {
{
auto f = static_cast<Filter*>(src_filter.Remove(src)); auto f = static_cast<Filter*>(src_filter.Remove(src));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::RemoveDst(const IPAddr& dst) bool PacketFilter::RemoveDst(const IPAddr& dst) {
{
auto f = static_cast<Filter*>(dst_filter.Remove(dst, 128)); auto f = static_cast<Filter*>(dst_filter.Remove(dst, 128));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::RemoveDst(Val* dst) bool PacketFilter::RemoveDst(Val* dst) {
{
auto f = static_cast<Filter*>(dst_filter.Remove(dst)); auto f = static_cast<Filter*>(dst_filter.Remove(dst));
delete f; delete f;
return f != nullptr; return f != nullptr;
} }
bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen) {
{
Filter* f = (Filter*)src_filter.Lookup(ip->SrcAddr(), 128); Filter* f = (Filter*)src_filter.Lookup(ip->SrcAddr(), 128);
if ( f ) if ( f )
return MatchFilter(*f, *ip, len, caplen); return MatchFilter(*f, *ip, len, caplen);
@ -95,17 +83,14 @@ bool PacketFilter::Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen)
return default_match; return default_match;
} }
bool PacketFilter::MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen) bool PacketFilter::MatchFilter(const Filter& f, const IP_Hdr& ip, int len, int caplen) {
{ if ( ip.NextProto() == IPPROTO_TCP && f.tcp_flags ) {
if ( ip.NextProto() == IPPROTO_TCP && f.tcp_flags )
{
// Caution! The packet sanity checks have not been performed yet // Caution! The packet sanity checks have not been performed yet
int ip_hdr_len = ip.HdrLen(); int ip_hdr_len = ip.HdrLen();
len -= ip_hdr_len; // remove IP header len -= ip_hdr_len; // remove IP header
caplen -= ip_hdr_len; caplen -= ip_hdr_len;
if ( (unsigned int)len < sizeof(struct tcphdr) || if ( (unsigned int)len < sizeof(struct tcphdr) || (unsigned int)caplen < sizeof(struct tcphdr) )
(unsigned int)caplen < sizeof(struct tcphdr) )
// Packet too short, will be dropped anyway. // Packet too short, will be dropped anyway.
return false; return false;

View file

@ -7,17 +7,14 @@
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
#include "zeek/PrefixTable.h" #include "zeek/PrefixTable.h"
namespace zeek namespace zeek {
{
class IP_Hdr; class IP_Hdr;
class Val; class Val;
namespace detail namespace detail {
{
class PacketFilter class PacketFilter {
{
public: public:
explicit PacketFilter(bool arg_default); explicit PacketFilter(bool arg_default);
~PacketFilter() {} ~PacketFilter() {}
@ -41,8 +38,7 @@ public:
bool Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen); bool Match(const std::shared_ptr<IP_Hdr>& ip, int len, int caplen);
private: private:
struct Filter struct Filter {
{
uint32_t tcp_flags; uint32_t tcp_flags;
double probability; double probability;
}; };

View file

@ -9,11 +9,9 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
namespace zeek::detail namespace zeek::detail {
{
static void pipe_fail(int eno) static void pipe_fail(int eno) {
{
char tmp[256]; char tmp[256];
zeek::util::zeek_strerror_r(eno, tmp, sizeof(tmp)); zeek::util::zeek_strerror_r(eno, tmp, sizeof(tmp));
@ -23,15 +21,13 @@ static void pipe_fail(int eno)
fprintf(stderr, "Pipe failure: %s", tmp); fprintf(stderr, "Pipe failure: %s", tmp);
} }
static int set_flags(int fd, int flags) static int set_flags(int fd, int flags) {
{
auto rval = fcntl(fd, F_GETFD); auto rval = fcntl(fd, F_GETFD);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{
rval |= flags; rval |= flags;
if ( fcntl(fd, F_SETFD, rval) == -1 ) if ( fcntl(fd, F_SETFD, rval) == -1 )
@ -41,15 +37,13 @@ static int set_flags(int fd, int flags)
return rval; return rval;
} }
static int unset_flags(int fd, int flags) static int unset_flags(int fd, int flags) {
{
auto rval = fcntl(fd, F_GETFD); auto rval = fcntl(fd, F_GETFD);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{
rval &= ~flags; rval &= ~flags;
if ( fcntl(fd, F_SETFD, rval) == -1 ) if ( fcntl(fd, F_SETFD, rval) == -1 )
@ -59,15 +53,13 @@ static int unset_flags(int fd, int flags)
return rval; return rval;
} }
static int set_status_flags(int fd, int flags) static int set_status_flags(int fd, int flags) {
{
auto rval = fcntl(fd, F_GETFL); auto rval = fcntl(fd, F_GETFL);
if ( rval == -1 ) if ( rval == -1 )
pipe_fail(errno); pipe_fail(errno);
if ( flags ) if ( flags ) {
{
rval |= flags; rval |= flags;
if ( fcntl(fd, F_SETFL, rval) == -1 ) if ( fcntl(fd, F_SETFL, rval) == -1 )
@ -77,8 +69,7 @@ static int set_status_flags(int fd, int flags)
return rval; return rval;
} }
static int dup_or_fail(int fd, int flags, int status_flags) static int dup_or_fail(int fd, int flags, int status_flags) {
{
int rval = dup(fd); int rval = dup(fd);
if ( rval < 0 ) if ( rval < 0 )
@ -89,15 +80,12 @@ static int dup_or_fail(int fd, int flags, int status_flags)
return rval; return rval;
} }
Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* arg_fds) Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* arg_fds) {
{ if ( arg_fds ) {
if ( arg_fds )
{
fds[0] = arg_fds[0]; fds[0] = arg_fds[0];
fds[1] = arg_fds[1]; fds[1] = arg_fds[1];
} }
else else {
{
// pipe2 can set flags atomically, but not yet available everywhere. // pipe2 can set flags atomically, but not yet available everywhere.
if ( ::pipe(fds) ) if ( ::pipe(fds) )
pipe_fail(errno); pipe_fail(errno);
@ -109,26 +97,22 @@ Pipe::Pipe(int flags0, int flags1, int status_flags0, int status_flags1, int* ar
status_flags[1] = set_status_flags(fds[1], status_flags1); status_flags[1] = set_status_flags(fds[1], status_flags1);
} }
void Pipe::SetFlags(int arg_flags) void Pipe::SetFlags(int arg_flags) {
{
flags[0] = set_flags(fds[0], arg_flags); flags[0] = set_flags(fds[0], arg_flags);
flags[1] = set_flags(fds[1], arg_flags); flags[1] = set_flags(fds[1], arg_flags);
} }
void Pipe::UnsetFlags(int arg_flags) void Pipe::UnsetFlags(int arg_flags) {
{
flags[0] = unset_flags(fds[0], arg_flags); flags[0] = unset_flags(fds[0], arg_flags);
flags[1] = unset_flags(fds[1], arg_flags); flags[1] = unset_flags(fds[1], arg_flags);
} }
Pipe::~Pipe() Pipe::~Pipe() {
{
close(fds[0]); close(fds[0]);
close(fds[1]); close(fds[1]);
} }
Pipe::Pipe(const Pipe& other) Pipe::Pipe(const Pipe& other) {
{
fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]); fds[0] = dup_or_fail(other.fds[0], other.flags[0], other.status_flags[0]);
fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]); fds[1] = dup_or_fail(other.fds[1], other.flags[1], other.status_flags[1]);
flags[0] = other.flags[0]; flags[0] = other.flags[0];
@ -137,8 +121,7 @@ Pipe::Pipe(const Pipe& other)
status_flags[1] = other.status_flags[1]; status_flags[1] = other.status_flags[1];
} }
Pipe& Pipe::operator=(const Pipe& other) Pipe& Pipe::operator=(const Pipe& other) {
{
if ( this == &other ) if ( this == &other )
return *this; return *this;
@ -155,8 +138,6 @@ Pipe& Pipe::operator=(const Pipe& other)
PipePair::PipePair(int flags, int status_flags, int* fds) PipePair::PipePair(int flags, int status_flags, int* fds)
: pipes{Pipe(flags, flags, status_flags, status_flags, fds ? fds + 0 : nullptr), : pipes{Pipe(flags, flags, status_flags, status_flags, fds ? fds + 0 : nullptr),
Pipe(flags, flags, status_flags, status_flags, fds ? fds + 2 : nullptr)} Pipe(flags, flags, status_flags, status_flags, fds ? fds + 2 : nullptr)} {}
{
}
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -2,11 +2,9 @@
#pragma once #pragma once
namespace zeek::detail namespace zeek::detail {
{
class Pipe class Pipe {
{
public: public:
/** /**
* Create a pair of file descriptors via pipe(), or aborts if it cannot. * Create a pair of file descriptors via pipe(), or aborts if it cannot.
@ -18,8 +16,7 @@ public:
* than create ones from a new pipe. Should point to memory containing * than create ones from a new pipe. Should point to memory containing
* two consecutive file descriptors, the "read" one and then the "write" one. * two consecutive file descriptors, the "read" one and then the "write" one.
*/ */
explicit Pipe(int flags0 = 0, int flags1 = 0, int status_flags0 = 0, int status_flags1 = 0, explicit Pipe(int flags0 = 0, int flags1 = 0, int status_flags0 = 0, int status_flags1 = 0, int* fds = nullptr);
int* fds = nullptr);
/** /**
* Close the pair of file descriptors owned by the object. * Close the pair of file descriptors owned by the object.
@ -68,8 +65,7 @@ private:
/** /**
* A pair of pipes that can be used for bi-directional IPC. * A pair of pipes that can be used for bi-directional IPC.
*/ */
class PipePair class PipePair {
{
public: public:
/** /**
* Create a pair of pipes * Create a pair of pipes

View file

@ -17,15 +17,12 @@
using namespace std; using namespace std;
struct PolicyFile struct PolicyFile {
{ PolicyFile() {
PolicyFile()
{
filedata = nullptr; filedata = nullptr;
lmtime = 0; lmtime = 0;
} }
~PolicyFile() ~PolicyFile() {
{
delete[] filedata; delete[] filedata;
filedata = nullptr; filedata = nullptr;
} }
@ -38,17 +35,14 @@ struct PolicyFile
using PolicyFileMap = map<string, PolicyFile*>; using PolicyFileMap = map<string, PolicyFile*>;
static PolicyFileMap policy_files; static PolicyFileMap policy_files;
namespace zeek::detail namespace zeek::detail {
{
int how_many_lines_in(const char* policy_filename) int how_many_lines_in(const char* policy_filename) {
{
if ( ! policy_filename ) if ( ! policy_filename )
reporter->InternalError("NULL value passed to how_many_lines_in\n"); reporter->InternalError("NULL value passed to how_many_lines_in\n");
FILE* throwaway = fopen(policy_filename, "r"); FILE* throwaway = fopen(policy_filename, "r");
if ( ! throwaway ) if ( ! throwaway ) {
{
debug_msg("Could not open policy file: %s.\n", policy_filename); debug_msg("Could not open policy file: %s.\n", policy_filename);
return -1; return -1;
} }
@ -58,11 +52,9 @@ int how_many_lines_in(const char* policy_filename)
PolicyFileMap::iterator match; PolicyFileMap::iterator match;
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
debug_msg("Policy file %s was not loaded.\n", policy_filename); debug_msg("Policy file %s was not loaded.\n", policy_filename);
return -1; return -1;
} }
@ -72,9 +64,7 @@ int how_many_lines_in(const char* policy_filename)
return pf->lines.size(); return pf->lines.size();
} }
bool LoadPolicyFileText(const char* policy_filename, bool LoadPolicyFileText(const char* policy_filename, const std::optional<std::string>& preloaded_content) {
const std::optional<std::string>& preloaded_content)
{
if ( ! policy_filename ) if ( ! policy_filename )
return true; return true;
@ -84,26 +74,22 @@ bool LoadPolicyFileText(const char* policy_filename,
PolicyFile* pf = new PolicyFile; PolicyFile* pf = new PolicyFile;
policy_files.insert(PolicyFileMap::value_type(policy_filename, pf)); policy_files.insert(PolicyFileMap::value_type(policy_filename, pf));
if ( preloaded_content ) if ( preloaded_content ) {
{
auto size = preloaded_content->size(); auto size = preloaded_content->size();
pf->filedata = new char[size + 1]; pf->filedata = new char[size + 1];
memcpy(pf->filedata, preloaded_content->data(), size); memcpy(pf->filedata, preloaded_content->data(), size);
pf->filedata[size] = '\0'; pf->filedata[size] = '\0';
} }
else else {
{
FILE* f = fopen(policy_filename, "r"); FILE* f = fopen(policy_filename, "r");
if ( ! f ) if ( ! f ) {
{
debug_msg("Could not open policy file: %s.\n", policy_filename); debug_msg("Could not open policy file: %s.\n", policy_filename);
return false; return false;
} }
struct stat st; struct stat st;
if ( fstat(fileno(f), &st) != 0 ) if ( fstat(fileno(f), &st) != 0 ) {
{
char buf[256]; char buf[256];
util::zeek_strerror_r(errno, buf, sizeof(buf)); util::zeek_strerror_r(errno, buf, sizeof(buf));
reporter->Error("fstat failed on %s: %s", policy_filename, buf); reporter->Error("fstat failed on %s: %s", policy_filename, buf);
@ -127,10 +113,8 @@ bool LoadPolicyFileText(const char* policy_filename,
// Separate the string by newlines. // Separate the string by newlines.
pf->lines.push_back(pf->filedata); pf->lines.push_back(pf->filedata);
for ( char* iter = pf->filedata; *iter; ++iter ) for ( char* iter = pf->filedata; *iter; ++iter ) {
{ if ( *iter == '\n' ) {
if ( *iter == '\n' )
{
*iter = 0; *iter = 0;
if ( *(iter + 1) ) if ( *(iter + 1) )
pf->lines.push_back(iter + 1); pf->lines.push_back(iter + 1);
@ -144,15 +128,12 @@ bool LoadPolicyFileText(const char* policy_filename,
} }
// REMEMBER: line number arguments are indexed from 0. // REMEMBER: line number arguments are indexed from 0.
bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool show_numbers) {
bool show_numbers)
{
if ( ! policy_filename ) if ( ! policy_filename )
return true; return true;
FILE* throwaway = fopen(policy_filename, "r"); FILE* throwaway = fopen(policy_filename, "r");
if ( ! throwaway ) if ( ! throwaway ) {
{
debug_msg("Could not open policy file: %s.\n", policy_filename); debug_msg("Could not open policy file: %s.\n", policy_filename);
return false; return false;
} }
@ -162,11 +143,9 @@ bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned i
PolicyFileMap::iterator match; PolicyFileMap::iterator match;
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
match = policy_files.find(policy_filename); match = policy_files.find(policy_filename);
if ( match == policy_files.end() ) if ( match == policy_files.end() ) {
{
debug_msg("Policy file %s was not loaded.\n", policy_filename); debug_msg("Policy file %s was not loaded.\n", policy_filename);
return false; return false;
} }
@ -177,18 +156,15 @@ bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned i
if ( start_line < 1 ) if ( start_line < 1 )
start_line = 1; start_line = 1;
if ( start_line > pf->lines.size() ) if ( start_line > pf->lines.size() ) {
{ debug_msg("Line number %d out of range; %s has %d lines\n", start_line, policy_filename, int(pf->lines.size()));
debug_msg("Line number %d out of range; %s has %d lines\n", start_line, policy_filename,
int(pf->lines.size()));
return false; return false;
} }
if ( start_line + how_many_lines - 1 > pf->lines.size() ) if ( start_line + how_many_lines - 1 > pf->lines.size() )
how_many_lines = pf->lines.size() - start_line + 1; how_many_lines = pf->lines.size() - start_line + 1;
for ( unsigned int i = 0; i < how_many_lines; ++i ) for ( unsigned int i = 0; i < how_many_lines; ++i ) {
{
if ( show_numbers ) if ( show_numbers )
debug_msg("%d\t", i + start_line); debug_msg("%d\t", i + start_line);

View file

@ -17,16 +17,13 @@
#include <optional> #include <optional>
#include <string> #include <string>
namespace zeek::detail namespace zeek::detail {
{
int how_many_lines_in(const char* policy_filename); int how_many_lines_in(const char* policy_filename);
bool LoadPolicyFileText(const char* policy_filename, bool LoadPolicyFileText(const char* policy_filename, const std::optional<std::string>& preloaded_content = {});
const std::optional<std::string>& preloaded_content = {});
// start_line is 1-based (the intuitive way) // start_line is 1-based (the intuitive way)
bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool PrintLines(const char* policy_filename, unsigned int start_line, unsigned int how_many_lines, bool show_numbers);
bool show_numbers);
} // namespace zeek::detail } // namespace zeek::detail

View file

@ -3,11 +3,9 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/Val.h" #include "zeek/Val.h"
namespace zeek::detail namespace zeek::detail {
{
prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width) prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width) {
{
prefix_t* prefix = (prefix_t*)util::safe_malloc(sizeof(prefix_t)); prefix_t* prefix = (prefix_t*)util::safe_malloc(sizeof(prefix_t));
addr.CopyIPv6(&prefix->add.sin6); addr.CopyIPv6(&prefix->add.sin6);
@ -18,21 +16,17 @@ prefix_t* PrefixTable::MakePrefix(const IPAddr& addr, int width)
return prefix; return prefix;
} }
IPPrefix PrefixTable::PrefixToIPPrefix(prefix_t* prefix) IPPrefix PrefixTable::PrefixToIPPrefix(prefix_t* prefix) {
{ return IPPrefix(IPAddr(IPv6, reinterpret_cast<const uint32_t*>(&prefix->add.sin6), IPAddr::Network), prefix->bitlen,
return IPPrefix( true);
IPAddr(IPv6, reinterpret_cast<const uint32_t*>(&prefix->add.sin6), IPAddr::Network),
prefix->bitlen, true);
} }
void* PrefixTable::Insert(const IPAddr& addr, int width, void* data) void* PrefixTable::Insert(const IPAddr& addr, int width, void* data) {
{
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
patricia_node_t* node = patricia_lookup(tree, prefix); patricia_node_t* node = patricia_lookup(tree, prefix);
Deref_Prefix(prefix); Deref_Prefix(prefix);
if ( ! node ) if ( ! node ) {
{
reporter->InternalWarning("Cannot create node in patricia tree"); reporter->InternalWarning("Cannot create node in patricia tree");
return nullptr; return nullptr;
} }
@ -46,30 +40,21 @@ void* PrefixTable::Insert(const IPAddr& addr, int width, void* data)
return old; return old;
} }
void* PrefixTable::Insert(const Val* value, void* data) void* PrefixTable::Insert(const Val* value, void* data) {
{
// [elem] -> elem // [elem] -> elem
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
value = value->AsListVal()->Idx(0).get(); value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Insert(value->AsAddr(), 128, data); break;
case TYPE_ADDR:
return Insert(value->AsAddr(), 128, data);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Insert(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), data); break;
return Insert(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), data);
break;
default: default: reporter->InternalWarning("Wrong index type for PrefixTable"); return nullptr;
reporter->InternalWarning("Wrong index type for PrefixTable");
return nullptr;
} }
} }
std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr, int width) const std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr, int width) const {
{
std::list<std::tuple<IPPrefix, void*>> out; std::list<std::tuple<IPPrefix, void*>> out;
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
@ -86,16 +71,13 @@ std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const IPAddr& addr,
return out; return out;
} }
std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const SubNetVal* value) const std::list<std::tuple<IPPrefix, void*>> PrefixTable::FindAll(const SubNetVal* value) const {
{
return FindAll(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6()); return FindAll(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6());
} }
void* PrefixTable::Lookup(const IPAddr& addr, int width, bool exact) const void* PrefixTable::Lookup(const IPAddr& addr, int width, bool exact) const {
{
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
patricia_node_t* node = exact ? patricia_search_exact(tree, prefix) patricia_node_t* node = exact ? patricia_search_exact(tree, prefix) : patricia_search_best(tree, prefix);
: patricia_search_best(tree, prefix);
int elems = 0; int elems = 0;
patricia_node_t** list = nullptr; patricia_node_t** list = nullptr;
@ -104,31 +86,23 @@ void* PrefixTable::Lookup(const IPAddr& addr, int width, bool exact) const
return node ? node->data : nullptr; return node ? node->data : nullptr;
} }
void* PrefixTable::Lookup(const Val* value, bool exact) const void* PrefixTable::Lookup(const Val* value, bool exact) const {
{
// [elem] -> elem // [elem] -> elem
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
value = value->AsListVal()->Idx(0).get(); value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Lookup(value->AsAddr(), 128, exact); break;
case TYPE_ADDR:
return Lookup(value->AsAddr(), 128, exact);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Lookup(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), exact); break;
return Lookup(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6(), exact);
break;
default: default:
reporter->InternalWarning("Wrong index type %d for PrefixTable", reporter->InternalWarning("Wrong index type %d for PrefixTable", value->GetType()->Tag());
value->GetType()->Tag());
return nullptr; return nullptr;
} }
} }
void* PrefixTable::Remove(const IPAddr& addr, int width) void* PrefixTable::Remove(const IPAddr& addr, int width) {
{
prefix_t* prefix = MakePrefix(addr, width); prefix_t* prefix = MakePrefix(addr, width);
patricia_node_t* node = patricia_search_exact(tree, prefix); patricia_node_t* node = patricia_search_exact(tree, prefix);
Deref_Prefix(prefix); Deref_Prefix(prefix);
@ -142,30 +116,21 @@ void* PrefixTable::Remove(const IPAddr& addr, int width)
return old; return old;
} }
void* PrefixTable::Remove(const Val* value) void* PrefixTable::Remove(const Val* value) {
{
// [elem] -> elem // [elem] -> elem
if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 ) if ( value->GetType()->Tag() == TYPE_LIST && value->AsListVal()->Length() == 1 )
value = value->AsListVal()->Idx(0).get(); value = value->AsListVal()->Idx(0).get();
switch ( value->GetType()->Tag() ) switch ( value->GetType()->Tag() ) {
{ case TYPE_ADDR: return Remove(value->AsAddr(), 128); break;
case TYPE_ADDR:
return Remove(value->AsAddr(), 128);
break;
case TYPE_SUBNET: case TYPE_SUBNET: return Remove(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6()); break;
return Remove(value->AsSubNet().Prefix(), value->AsSubNet().LengthIPv6());
break;
default: default: reporter->InternalWarning("Wrong index type for PrefixTable"); return nullptr;
reporter->InternalWarning("Wrong index type for PrefixTable");
return nullptr;
} }
} }
PrefixTable::iterator PrefixTable::InitIterator() PrefixTable::iterator PrefixTable::InitIterator() {
{
iterator i; iterator i;
i.Xsp = i.Xstack; i.Xsp = i.Xstack;
i.Xrn = tree->head; i.Xrn = tree->head;
@ -173,16 +138,13 @@ PrefixTable::iterator PrefixTable::InitIterator()
return i; return i;
} }
void* PrefixTable::GetNext(iterator* i) void* PrefixTable::GetNext(iterator* i) {
{ while ( true ) {
while ( true )
{
i->Xnode = i->Xrn; i->Xnode = i->Xrn;
if ( ! i->Xnode ) if ( ! i->Xnode )
return nullptr; return nullptr;
if ( i->Xrn->l ) if ( i->Xrn->l ) {
{
if ( i->Xrn->r ) if ( i->Xrn->r )
*i->Xsp++ = i->Xrn->r; *i->Xsp++ = i->Xrn->r;

View file

@ -1,7 +1,6 @@
#pragma once #pragma once
extern "C" extern "C" {
{
#include "zeek/3rdparty/patricia.h" #include "zeek/3rdparty/patricia.h"
} }
@ -10,20 +9,16 @@ extern "C"
#include "zeek/IPAddr.h" #include "zeek/IPAddr.h"
namespace zeek namespace zeek {
{
class Val; class Val;
class SubNetVal; class SubNetVal;
namespace detail namespace detail {
{
class PrefixTable class PrefixTable {
{
private: private:
struct iterator struct iterator {
{
patricia_node_t* Xstack[PATRICIA_MAXBITS + 1]; patricia_node_t* Xstack[PATRICIA_MAXBITS + 1];
patricia_node_t** Xsp; patricia_node_t** Xsp;
patricia_node_t* Xrn; patricia_node_t* Xrn;
@ -31,8 +26,7 @@ private:
}; };
public: public:
PrefixTable() PrefixTable() {
{
tree = New_Patricia(128); tree = New_Patricia(128);
delete_function = nullptr; delete_function = nullptr;
} }

View file

@ -10,24 +10,18 @@
#include "zeek/Reporter.h" #include "zeek/Reporter.h"
#include "zeek/util.h" #include "zeek/util.h"
namespace zeek::detail namespace zeek::detail {
{
PriorityQueue::PriorityQueue(int initial_size) : max_heap_size(initial_size) PriorityQueue::PriorityQueue(int initial_size) : max_heap_size(initial_size) { heap = new PQ_Element*[max_heap_size]; }
{
heap = new PQ_Element*[max_heap_size];
}
PriorityQueue::~PriorityQueue() PriorityQueue::~PriorityQueue() {
{
for ( int i = 0; i < heap_size; ++i ) for ( int i = 0; i < heap_size; ++i )
delete heap[i]; delete heap[i];
delete[] heap; delete[] heap;
} }
PQ_Element* PriorityQueue::Remove() PQ_Element* PriorityQueue::Remove() {
{
if ( heap_size == 0 ) if ( heap_size == 0 )
return nullptr; return nullptr;
@ -41,8 +35,7 @@ PQ_Element* PriorityQueue::Remove()
return top; return top;
} }
PQ_Element* PriorityQueue::Remove(PQ_Element* e) PQ_Element* PriorityQueue::Remove(PQ_Element* e) {
{
if ( e->Offset() < 0 || e->Offset() >= heap_size || heap[e->Offset()] != e ) if ( e->Offset() < 0 || e->Offset() >= heap_size || heap[e->Offset()] != e )
return nullptr; // not in heap return nullptr; // not in heap
@ -57,8 +50,7 @@ PQ_Element* PriorityQueue::Remove(PQ_Element* e)
return e2; return e2;
} }
bool PriorityQueue::Add(PQ_Element* e) bool PriorityQueue::Add(PQ_Element* e) {
{
SetElement(heap_size, e); SetElement(heap_size, e);
BubbleUp(heap_size); BubbleUp(heap_size);
@ -74,8 +66,7 @@ bool PriorityQueue::Add(PQ_Element* e)
return true; return true;
} }
bool PriorityQueue::Resize(int new_size) bool PriorityQueue::Resize(int new_size) {
{
PQ_Element** tmp = new PQ_Element*[new_size]; PQ_Element** tmp = new PQ_Element*[new_size];
for ( int i = 0; i < max_heap_size; ++i ) for ( int i = 0; i < max_heap_size; ++i )
tmp[i] = heap[i]; tmp[i] = heap[i];
@ -88,21 +79,18 @@ bool PriorityQueue::Resize(int new_size)
return heap != nullptr; return heap != nullptr;
} }
void PriorityQueue::BubbleUp(int bin) void PriorityQueue::BubbleUp(int bin) {
{
if ( bin == 0 ) if ( bin == 0 )
return; return;
int p = Parent(bin); int p = Parent(bin);
if ( heap[p]->Time() > heap[bin]->Time() ) if ( heap[p]->Time() > heap[bin]->Time() ) {
{
Swap(p, bin); Swap(p, bin);
BubbleUp(p); BubbleUp(p);
} }
} }
void PriorityQueue::BubbleDown(int bin) void PriorityQueue::BubbleDown(int bin) {
{
double v = heap[bin]->Time(); double v = heap[bin]->Time();
int l = LeftChild(bin); int l = LeftChild(bin);
@ -111,28 +99,23 @@ void PriorityQueue::BubbleDown(int bin)
if ( l >= heap_size ) if ( l >= heap_size )
return; // No children. return; // No children.
if ( r >= heap_size ) if ( r >= heap_size ) { // Just a left child.
{ // Just a left child.
if ( heap[l]->Time() < v ) if ( heap[l]->Time() < v )
Swap(l, bin); Swap(l, bin);
} }
else else {
{
double lv = heap[l]->Time(); double lv = heap[l]->Time();
double rv = heap[r]->Time(); double rv = heap[r]->Time();
if ( lv < rv ) if ( lv < rv ) {
{ if ( lv < v ) {
if ( lv < v )
{
Swap(l, bin); Swap(l, bin);
BubbleDown(l); BubbleDown(l);
} }
} }
else if ( rv < v ) else if ( rv < v ) {
{
Swap(r, bin); Swap(r, bin);
BubbleDown(r); BubbleDown(r);
} }

View file

@ -7,13 +7,11 @@
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
namespace zeek::detail namespace zeek::detail {
{
class PriorityQueue; class PriorityQueue;
class PQ_Element class PQ_Element {
{
public: public:
explicit PQ_Element(double t) : time(t) {} explicit PQ_Element(double t) : time(t) {}
virtual ~PQ_Element() = default; virtual ~PQ_Element() = default;
@ -31,15 +29,13 @@ protected:
int offset = -1; int offset = -1;
}; };
class PriorityQueue class PriorityQueue {
{
public: public:
explicit PriorityQueue(int initial_size = 16); explicit PriorityQueue(int initial_size = 16);
~PriorityQueue(); ~PriorityQueue();
// Returns the top of queue, or nil if the queue is empty. // Returns the top of queue, or nil if the queue is empty.
PQ_Element* Top() const PQ_Element* Top() const {
{
if ( heap_size == 0 ) if ( heap_size == 0 )
return nullptr; return nullptr;
@ -74,14 +70,12 @@ protected:
int RightChild(int bin) const { return LeftChild(bin) + 1; } int RightChild(int bin) const { return LeftChild(bin) + 1; }
void SetElement(int bin, PQ_Element* e) void SetElement(int bin, PQ_Element* e) {
{
heap[bin] = e; heap[bin] = e;
e->SetOffset(bin); e->SetOffset(bin);
} }
void Swap(int bin1, int bin2) void Swap(int bin1, int bin2) {
{
PQ_Element* t = heap[bin1]; PQ_Element* t = heap[bin1];
SetElement(bin1, heap[bin2]); SetElement(bin1, heap[bin2]);
SetElement(bin2, t); SetElement(bin2, t);

196
src/RE.cc
View file

@ -24,14 +24,11 @@ extern int RE_parse(void);
extern void RE_set_input(const char* str); extern void RE_set_input(const char* str);
extern void RE_done_with_scan(); extern void RE_done_with_scan();
namespace zeek namespace zeek {
{ namespace detail {
namespace detail
{
Specific_RE_Matcher::Specific_RE_Matcher(match_type arg_mt, bool arg_multiline) Specific_RE_Matcher::Specific_RE_Matcher(match_type arg_mt, bool arg_multiline)
: mt(arg_mt), multiline(arg_multiline), equiv_class(NUM_SYM) : mt(arg_mt), multiline(arg_multiline), equiv_class(NUM_SYM) {
{
any_ccl = nullptr; any_ccl = nullptr;
single_line_ccl = nullptr; single_line_ccl = nullptr;
dfa = nullptr; dfa = nullptr;
@ -39,8 +36,7 @@ Specific_RE_Matcher::Specific_RE_Matcher(match_type arg_mt, bool arg_multiline)
accepted = new AcceptingSet(); accepted = new AcceptingSet();
} }
Specific_RE_Matcher::~Specific_RE_Matcher() Specific_RE_Matcher::~Specific_RE_Matcher() {
{
for ( int i = 0; i < ccl_list.length(); ++i ) for ( int i = 0; i < ccl_list.length(); ++i )
delete ccl_list[i]; delete ccl_list[i];
@ -48,12 +44,9 @@ Specific_RE_Matcher::~Specific_RE_Matcher()
delete accepted; delete accepted;
} }
CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode) CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode) {
{ if ( single_line_mode ) {
if ( single_line_mode ) if ( ! single_line_ccl ) {
{
if ( ! single_line_ccl )
{
single_line_ccl = new CCL(); single_line_ccl = new CCL();
single_line_ccl->Negate(); single_line_ccl->Negate();
EC()->CCL_Use(single_line_ccl); EC()->CCL_Use(single_line_ccl);
@ -62,8 +55,7 @@ CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode)
return single_line_ccl; return single_line_ccl;
} }
if ( ! any_ccl ) if ( ! any_ccl ) {
{
any_ccl = new CCL(); any_ccl = new CCL();
if ( ! multiline ) if ( ! multiline )
any_ccl->Add('\n'); any_ccl->Add('\n');
@ -74,52 +66,42 @@ CCL* Specific_RE_Matcher::AnyCCL(bool single_line_mode)
return any_ccl; return any_ccl;
} }
void Specific_RE_Matcher::ConvertCCLs() void Specific_RE_Matcher::ConvertCCLs() {
{
for ( int i = 0; i < ccl_list.length(); ++i ) for ( int i = 0; i < ccl_list.length(); ++i )
equiv_class.ConvertCCL(ccl_list[i]); equiv_class.ConvertCCL(ccl_list[i]);
} }
void Specific_RE_Matcher::AddPat(const char* new_pat) void Specific_RE_Matcher::AddPat(const char* new_pat) {
{
if ( mt == MATCH_EXACTLY ) if ( mt == MATCH_EXACTLY )
AddExactPat(new_pat); AddExactPat(new_pat);
else else
AddAnywherePat(new_pat); AddAnywherePat(new_pat);
} }
void Specific_RE_Matcher::AddAnywherePat(const char* new_pat) void Specific_RE_Matcher::AddAnywherePat(const char* new_pat) {
{
AddPat(new_pat, "^?(.|\\n)*(%s)", "(%s)|(^?(.|\\n)*(%s))"); AddPat(new_pat, "^?(.|\\n)*(%s)", "(%s)|(^?(.|\\n)*(%s))");
} }
void Specific_RE_Matcher::AddExactPat(const char* new_pat) void Specific_RE_Matcher::AddExactPat(const char* new_pat) { AddPat(new_pat, "^?(%s)$?", "(%s)|(^?(%s)$?)"); }
{
AddPat(new_pat, "^?(%s)$?", "(%s)|(^?(%s)$?)");
}
void Specific_RE_Matcher::AddPat(const char* new_pat, const char* orig_fmt, const char* app_fmt) void Specific_RE_Matcher::AddPat(const char* new_pat, const char* orig_fmt, const char* app_fmt) {
{
if ( ! pattern_text.empty() ) if ( ! pattern_text.empty() )
pattern_text = util::fmt(app_fmt, pattern_text.c_str(), new_pat); pattern_text = util::fmt(app_fmt, pattern_text.c_str(), new_pat);
else else
pattern_text = util::fmt(orig_fmt, new_pat); pattern_text = util::fmt(orig_fmt, new_pat);
} }
void Specific_RE_Matcher::MakeCaseInsensitive() void Specific_RE_Matcher::MakeCaseInsensitive() {
{
const char fmt[] = "(?i:%s)"; const char fmt[] = "(?i:%s)";
pattern_text = util::fmt(fmt, pattern_text.c_str()); pattern_text = util::fmt(fmt, pattern_text.c_str());
} }
void Specific_RE_Matcher::MakeSingleLine() void Specific_RE_Matcher::MakeSingleLine() {
{
const char fmt[] = "(?s:%s)"; const char fmt[] = "(?s:%s)";
pattern_text = util::fmt(fmt, pattern_text.c_str()); pattern_text = util::fmt(fmt, pattern_text.c_str());
} }
bool Specific_RE_Matcher::Compile(bool lazy) bool Specific_RE_Matcher::Compile(bool lazy) {
{
if ( pattern_text.empty() ) if ( pattern_text.empty() )
return false; return false;
@ -129,8 +111,7 @@ bool Specific_RE_Matcher::Compile(bool lazy)
int parse_status = RE_parse(); int parse_status = RE_parse();
RE_done_with_scan(); RE_done_with_scan();
if ( parse_status ) if ( parse_status ) {
{
reporter->Error("error compiling pattern /%s/", pattern_text.c_str()); reporter->Error("error compiling pattern /%s/", pattern_text.c_str());
Unref(nfa); Unref(nfa);
nfa = nullptr; nfa = nullptr;
@ -150,8 +131,7 @@ bool Specific_RE_Matcher::Compile(bool lazy)
return true; return true;
} }
bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx) bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx) {
{
if ( (size_t)set.length() != idx.size() ) if ( (size_t)set.length() != idx.size() )
reporter->InternalError("compileset: lengths of sets differ"); reporter->InternalError("compileset: lengths of sets differ");
@ -159,14 +139,12 @@ bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx
NFA_Machine* set_nfa = nullptr; NFA_Machine* set_nfa = nullptr;
loop_over_list(set, i) loop_over_list(set, i) {
{
RE_set_input(set[i]); RE_set_input(set[i]);
int parse_status = RE_parse(); int parse_status = RE_parse();
RE_done_with_scan(); RE_done_with_scan();
if ( parse_status ) if ( parse_status ) {
{
reporter->Error("error compiling pattern /%s/", set[i]); reporter->Error("error compiling pattern /%s/", set[i]);
if ( set_nfa && set_nfa != nfa ) if ( set_nfa && set_nfa != nfa )
@ -197,8 +175,7 @@ bool Specific_RE_Matcher::CompileSet(const string_list& set, const int_list& idx
return true; return true;
} }
std::string Specific_RE_Matcher::LookupDef(const std::string& def) std::string Specific_RE_Matcher::LookupDef(const std::string& def) {
{
const auto& iter = defs.find(def); const auto& iter = defs.find(def);
if ( iter != defs.end() ) if ( iter != defs.end() )
return iter->second; return iter->second;
@ -206,39 +183,22 @@ std::string Specific_RE_Matcher::LookupDef(const std::string& def)
return {}; return {};
} }
bool Specific_RE_Matcher::MatchAll(const char* s) bool Specific_RE_Matcher::MatchAll(const char* s) { return MatchAll((const u_char*)(s), strlen(s)); }
{
return MatchAll((const u_char*)(s), strlen(s));
}
bool Specific_RE_Matcher::MatchAll(const String* s) bool Specific_RE_Matcher::MatchAll(const String* s) {
{
// s->Len() does not include '\0'. // s->Len() does not include '\0'.
return MatchAll(s->Bytes(), s->Len()); return MatchAll(s->Bytes(), s->Len());
} }
int Specific_RE_Matcher::Match(const char* s) int Specific_RE_Matcher::Match(const char* s) { return Match((const u_char*)(s), strlen(s)); }
{
return Match((const u_char*)(s), strlen(s));
}
int Specific_RE_Matcher::Match(const String* s) int Specific_RE_Matcher::Match(const String* s) { return Match(s->Bytes(), s->Len()); }
{
return Match(s->Bytes(), s->Len());
}
int Specific_RE_Matcher::LongestMatch(const char* s) int Specific_RE_Matcher::LongestMatch(const char* s) { return LongestMatch((const u_char*)(s), strlen(s)); }
{
return LongestMatch((const u_char*)(s), strlen(s));
}
int Specific_RE_Matcher::LongestMatch(const String* s) int Specific_RE_Matcher::LongestMatch(const String* s) { return LongestMatch(s->Bytes(), s->Len()); }
{
return LongestMatch(s->Bytes(), s->Len());
}
bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n) bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n) {
{
if ( ! dfa ) if ( ! dfa )
// An empty pattern matches "all" iff what's being // An empty pattern matches "all" iff what's being
// matched is empty. // matched is empty.
@ -247,8 +207,7 @@ bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n)
DFA_State* d = dfa->StartState(); DFA_State* d = dfa->StartState();
d = d->Xtion(ecs[SYM_BOL], dfa); d = d->Xtion(ecs[SYM_BOL], dfa);
while ( d ) while ( d ) {
{
if ( --n < 0 ) if ( --n < 0 )
break; break;
@ -262,8 +221,7 @@ bool Specific_RE_Matcher::MatchAll(const u_char* bv, int n)
return d && d->Accept() != nullptr; return d && d->Accept() != nullptr;
} }
int Specific_RE_Matcher::Match(const u_char* bv, int n) int Specific_RE_Matcher::Match(const u_char* bv, int n) {
{
if ( ! dfa ) if ( ! dfa )
// An empty pattern matches anything. // An empty pattern matches anything.
return 1; return 1;
@ -274,8 +232,7 @@ int Specific_RE_Matcher::Match(const u_char* bv, int n)
if ( ! d ) if ( ! d )
return 0; return 0;
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
int ec = ecs[bv[i]]; int ec = ecs[bv[i]];
d = d->Xtion(ec, dfa); d = d->Xtion(ec, dfa);
if ( ! d ) if ( ! d )
@ -285,8 +242,7 @@ int Specific_RE_Matcher::Match(const u_char* bv, int n)
return i + 1; return i + 1;
} }
if ( d ) if ( d ) {
{
d = d->Xtion(ecs[SYM_EOL], dfa); d = d->Xtion(ecs[SYM_EOL], dfa);
if ( d && d->Accept() ) if ( d && d->Accept() )
return n > 0 ? n : 1; // we can't return 0 here for match... return n > 0 ? n : 1; // we can't return 0 here for match...
@ -295,23 +251,17 @@ int Specific_RE_Matcher::Match(const u_char* bv, int n)
return 0; return 0;
} }
void Specific_RE_Matcher::Dump(FILE* f) void Specific_RE_Matcher::Dump(FILE* f) { dfa->Dump(f); }
{
dfa->Dump(f);
}
inline void RE_Match_State::AddMatches(const AcceptingSet& as, MatchPos position) inline void RE_Match_State::AddMatches(const AcceptingSet& as, MatchPos position) {
{
using am_idx = std::pair<AcceptIdx, MatchPos>; using am_idx = std::pair<AcceptIdx, MatchPos>;
for ( AcceptingSet::const_iterator it = as.begin(); it != as.end(); ++it ) for ( AcceptingSet::const_iterator it = as.begin(); it != as.end(); ++it )
accepted_matches.insert(am_idx(*it, position)); accepted_matches.insert(am_idx(*it, position));
} }
bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool clear) bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool clear) {
{ if ( current_pos == -1 ) {
if ( current_pos == -1 )
{
// First call to Match(). // First call to Match().
if ( ! dfa ) if ( ! dfa )
return false; return false;
@ -340,8 +290,7 @@ bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool cle
int m = bol ? n + 1 : n; int m = bol ? n + 1 : n;
int e = eol ? -1 : 0; int e = eol ? -1 : 0;
while ( --m >= e ) while ( --m >= e ) {
{
if ( m == n ) if ( m == n )
ec = ecs[SYM_BOL]; ec = ecs[SYM_BOL];
else if ( m == -1 ) else if ( m == -1 )
@ -351,8 +300,7 @@ bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool cle
DFA_State* next_state = current_state->Xtion(ec, dfa); DFA_State* next_state = current_state->Xtion(ec, dfa);
if ( ! next_state ) if ( ! next_state ) {
{
current_state = nullptr; current_state = nullptr;
break; break;
} }
@ -370,8 +318,7 @@ bool RE_Match_State::Match(const u_char* bv, int n, bool bol, bool eol, bool cle
return accepted_matches.size() != old_matches; return accepted_matches.size() != old_matches;
} }
int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n) int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n) {
{
if ( ! dfa ) if ( ! dfa )
// An empty pattern matches anything. // An empty pattern matches anything.
return 0; return 0;
@ -387,8 +334,7 @@ int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n)
if ( d->Accept() ) if ( d->Accept() )
last_accept = 0; last_accept = 0;
for ( int i = 0; i < n; ++i ) for ( int i = 0; i < n; ++i ) {
{
int ec = ecs[bv[i]]; int ec = ecs[bv[i]];
d = d->Xtion(ec, dfa); d = d->Xtion(ec, dfa);
@ -399,8 +345,7 @@ int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n)
last_accept = i + 1; last_accept = i + 1;
} }
if ( d ) if ( d ) {
{
d = d->Xtion(ecs[SYM_EOL], dfa); d = d->Xtion(ecs[SYM_EOL], dfa);
if ( d && d->Accept() ) if ( d && d->Accept() )
return n; return n;
@ -409,8 +354,7 @@ int Specific_RE_Matcher::LongestMatch(const u_char* bv, int n)
return last_accept; return last_accept;
} }
static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2, const char* merge_op) static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2, const char* merge_op) {
{
const char* text1 = re1->PatternText(); const char* text1 = re1->PatternText();
const char* text2 = re2->PatternText(); const char* text2 = re2->PatternText();
@ -424,78 +368,61 @@ static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2, c
return merge; return merge;
} }
RE_Matcher* RE_Matcher_conjunction(const RE_Matcher* re1, const RE_Matcher* re2) RE_Matcher* RE_Matcher_conjunction(const RE_Matcher* re1, const RE_Matcher* re2) { return matcher_merge(re1, re2, ""); }
{
return matcher_merge(re1, re2, "");
}
RE_Matcher* RE_Matcher_disjunction(const RE_Matcher* re1, const RE_Matcher* re2) RE_Matcher* RE_Matcher_disjunction(const RE_Matcher* re1, const RE_Matcher* re2) {
{
return matcher_merge(re1, re2, "|"); return matcher_merge(re1, re2, "|");
} }
} // namespace detail } // namespace detail
RE_Matcher::RE_Matcher() RE_Matcher::RE_Matcher() {
{
re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE); re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE);
re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY); re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY);
} }
RE_Matcher::RE_Matcher(const char* pat) : orig_text(pat) RE_Matcher::RE_Matcher(const char* pat) : orig_text(pat) {
{
re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE); re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE);
re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY); re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY);
AddPat(pat); AddPat(pat);
} }
RE_Matcher::RE_Matcher(const char* exact_pat, const char* anywhere_pat) RE_Matcher::RE_Matcher(const char* exact_pat, const char* anywhere_pat) {
{
re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE); re_anywhere = new detail::Specific_RE_Matcher(detail::MATCH_ANYWHERE);
re_anywhere->SetPat(anywhere_pat); re_anywhere->SetPat(anywhere_pat);
re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY); re_exact = new detail::Specific_RE_Matcher(detail::MATCH_EXACTLY);
re_exact->SetPat(exact_pat); re_exact->SetPat(exact_pat);
} }
RE_Matcher::~RE_Matcher() RE_Matcher::~RE_Matcher() {
{
delete re_anywhere; delete re_anywhere;
delete re_exact; delete re_exact;
} }
void RE_Matcher::AddPat(const char* new_pat) void RE_Matcher::AddPat(const char* new_pat) {
{
re_anywhere->AddPat(new_pat); re_anywhere->AddPat(new_pat);
re_exact->AddPat(new_pat); re_exact->AddPat(new_pat);
} }
void RE_Matcher::MakeCaseInsensitive() void RE_Matcher::MakeCaseInsensitive() {
{
re_anywhere->MakeCaseInsensitive(); re_anywhere->MakeCaseInsensitive();
re_exact->MakeCaseInsensitive(); re_exact->MakeCaseInsensitive();
is_case_insensitive = true; is_case_insensitive = true;
} }
void RE_Matcher::MakeSingleLine() void RE_Matcher::MakeSingleLine() {
{
re_anywhere->MakeSingleLine(); re_anywhere->MakeSingleLine();
re_exact->MakeSingleLine(); re_exact->MakeSingleLine();
is_single_line = true; is_single_line = true;
} }
bool RE_Matcher::Compile(bool lazy) bool RE_Matcher::Compile(bool lazy) { return re_anywhere->Compile(lazy) && re_exact->Compile(lazy); }
{
return re_anywhere->Compile(lazy) && re_exact->Compile(lazy);
}
TEST_SUITE("re_matcher") TEST_SUITE("re_matcher") {
{ TEST_CASE("simple_pattern") {
TEST_CASE("simple_pattern")
{
RE_Matcher match("[0-9]+"); RE_Matcher match("[0-9]+");
match.Compile(); match.Compile();
CHECK(strcmp(match.OrigText(), "[0-9]+") == 0); CHECK(strcmp(match.OrigText(), "[0-9]+") == 0);
@ -513,8 +440,7 @@ TEST_SUITE("re_matcher")
CHECK(match.MatchAnywhere("abcd") == 0); CHECK(match.MatchAnywhere("abcd") == 0);
} }
TEST_CASE("case_insensitive_mode") TEST_CASE("case_insensitive_mode") {
{
RE_Matcher match("[a-z]+"); RE_Matcher match("[a-z]+");
match.MakeCaseInsensitive(); match.MakeCaseInsensitive();
match.Compile(); match.Compile();
@ -523,8 +449,7 @@ TEST_SUITE("re_matcher")
CHECK(match.MatchExactly("abcDEF")); CHECK(match.MatchExactly("abcDEF"));
} }
TEST_CASE("multi_pattern") TEST_CASE("multi_pattern") {
{
RE_Matcher match("[0-9]+"); RE_Matcher match("[0-9]+");
match.AddPat("[a-z]+"); match.AddPat("[a-z]+");
match.Compile(); match.Compile();
@ -536,8 +461,7 @@ TEST_SUITE("re_matcher")
CHECK_FALSE(match.MatchExactly("abc123")); CHECK_FALSE(match.MatchExactly("abc123"));
} }
TEST_CASE("modes_multi_pattern") TEST_CASE("modes_multi_pattern") {
{
RE_Matcher match("[a-m]+"); RE_Matcher match("[a-m]+");
match.MakeCaseInsensitive(); match.MakeCaseInsensitive();
@ -550,8 +474,7 @@ TEST_SUITE("re_matcher")
CHECK_FALSE(match.MatchExactly("NoP")); CHECK_FALSE(match.MatchExactly("NoP"));
} }
TEST_CASE("single_line_mode") TEST_CASE("single_line_mode") {
{
RE_Matcher match(".*"); RE_Matcher match(".*");
match.MakeSingleLine(); match.MakeSingleLine();
match.Compile(); match.Compile();
@ -580,8 +503,7 @@ TEST_SUITE("re_matcher")
CHECK(match4.MatchExactly("a\nc")); CHECK(match4.MatchExactly("a\nc"));
} }
TEST_CASE("disjunction") TEST_CASE("disjunction") {
{
RE_Matcher match1("a.c"); RE_Matcher match1("a.c");
match1.MakeSingleLine(); match1.MakeSingleLine();
match1.Compile(); match1.Compile();

Some files were not shown because too many files have changed in this diff Show more