From f5a76c1aedc7f8886bc6abef0dfaa8065684b1f6 Mon Sep 17 00:00:00 2001 From: Benjamin Bannier Date: Tue, 10 Oct 2023 21:13:34 +0200 Subject: [PATCH] Reformat Zeek in Spicy style This largely copies over Spicy's `.clang-format` configuration file. The one place where we deviate is header include order since Zeek depends on headers being included in a certain order. --- .clang-format | 168 +- auxil/spicy | 2 +- cmake | 2 +- doc | 2 +- src/Anon.cc | 601 +- src/Anon.h | 144 +- src/Attr.cc | 1324 ++- src/Attr.h | 173 +- src/Base64.cc | 406 +- src/Base64.h | 87 +- src/BifReturnVal.cc | 7 +- src/BifReturnVal.h | 24 +- src/CCL.cc | 54 +- src/CCL.h | 41 +- src/CompHash.cc | 1695 ++- src/CompHash.h | 82 +- src/Conn.cc | 660 +- src/Conn.h | 371 +- src/DFA.cc | 835 +- src/DFA.h | 173 +- src/DNS_Mapping.cc | 631 +- src/DNS_Mapping.h | 92 +- src/DNS_Mgr.cc | 3162 +++--- src/DNS_Mgr.h | 494 +- src/DbgBreakpoint.cc | 646 +- src/DbgBreakpoint.h | 121 +- src/DbgDisplay.h | 31 +- src/DbgWatch.cc | 15 +- src/DbgWatch.h | 25 +- src/Debug.cc | 1336 ++- src/Debug.h | 142 +- src/DebugCmds.cc | 973 +- src/DebugCmds.h | 46 +- src/DebugLogger.cc | 294 +- src/DebugLogger.h | 143 +- src/Desc.cc | 784 +- src/Desc.h | 319 +- src/Dict.cc | 815 +- src/Dict.h | 2595 +++-- src/Discard.cc | 215 +- src/Discard.h | 37 +- src/EquivClass.cc | 282 +- src/EquivClass.h | 58 +- src/Event.cc | 309 +- src/Event.h | 181 +- src/EventHandler.cc | 183 +- src/EventHandler.h | 123 +- src/EventRegistry.cc | 285 +- src/EventRegistry.h | 198 +- src/EventTrace.cc | 2188 ++-- src/EventTrace.h | 616 +- src/Expr.cc | 9428 ++++++++--------- src/Expr.h | 2321 ++-- src/File.cc | 469 +- src/File.h | 133 +- src/Flare.cc | 204 +- src/Flare.h | 78 +- src/Frag.cc | 693 +- src/Frag.h | 113 +- src/Frame.cc | 304 +- src/Frame.h | 370 +- src/Func.cc | 2041 ++-- src/Func.h | 534 +- src/Hash.cc | 1118 +- src/Hash.h | 576 +- src/ID.cc | 1137 +- src/ID.h | 220 +- src/IP.cc | 1529 ++- src/IP.h | 819 +- src/IPAddr.cc | 766 +- src/IPAddr.h | 1076 +- src/IntSet.cc | 24 +- src/IntSet.h | 80 +- src/IntrusivePtr.h | 316 +- src/List.cc | 205 +- src/List.h | 464 +- src/NFA.cc | 662 +- src/NFA.h | 146 +- src/NetVar.cc | 207 +- src/NetVar.h | 5 +- src/Notifier.cc | 112 +- src/Notifier.h | 153 +- src/Obj.cc | 294 +- src/Obj.h | 240 +- src/OpaqueVal.cc | 1492 ++- src/OpaqueVal.h | 552 +- src/Options.cc | 1114 +- src/Options.h | 120 +- src/Overflow.cc | 50 +- src/Overflow.h | 27 +- src/PacketFilter.cc | 183 +- src/PacketFilter.h | 68 +- src/Pipe.cc | 221 +- src/Pipe.h | 196 +- src/PolicyFile.cc | 268 +- src/PolicyFile.h | 11 +- src/PrefixTable.cc | 280 +- src/PrefixTable.h | 96 +- src/PriorityQueue.cc | 181 +- src/PriorityQueue.h | 116 +- src/RE.cc | 1038 +- src/RE.h | 285 +- src/RandTest.cc | 203 +- src/RandTest.h | 47 +- src/Reassem.cc | 547 +- src/Reassem.h | 438 +- src/Reporter.cc | 1199 +-- src/Reporter.h | 501 +- src/Rule.cc | 139 +- src/Rule.h | 157 +- src/RuleAction.cc | 169 +- src/RuleAction.h | 103 +- src/RuleCondition.cc | 285 +- src/RuleCondition.h | 136 +- src/RuleMatcher.cc | 2328 ++-- src/RuleMatcher.h | 513 +- src/RunState.cc | 715 +- src/RunState.h | 23 +- src/ScannedFile.cc | 66 +- src/ScannedFile.h | 52 +- src/Scope.cc | 303 +- src/Scope.h | 82 +- src/ScriptCoverageManager.cc | 226 +- src/ScriptCoverageManager.h | 152 +- src/ScriptProfile.cc | 365 +- src/ScriptProfile.h | 246 +- src/ScriptValidation.cc | 145 +- src/ScriptValidation.h | 5 +- src/SerializationFormat.cc | 790 +- src/SerializationFormat.h | 216 +- src/SmithWaterman.cc | 745 +- src/SmithWaterman.h | 152 +- src/Span.h | 158 +- src/Stats.cc | 756 +- src/Stats.h | 144 +- src/Stmt.cc | 3890 ++++--- src/Stmt.h | 958 +- src/StmtBase.h | 259 +- src/StmtEnums.h | 75 +- src/Tag.cc | 102 +- src/Tag.h | 221 +- src/Timer.cc | 260 +- src/Timer.h | 237 +- src/Traverse.cc | 30 +- src/Traverse.h | 67 +- src/TraverseTypes.h | 48 +- src/Trigger.cc | 767 +- src/Trigger.h | 247 +- src/TunnelEncapsulation.cc | 95 +- src/TunnelEncapsulation.h | 368 +- src/Type.cc | 4962 ++++----- src/Type.h | 1372 ++- src/UID.cc | 49 +- src/UID.h | 133 +- src/Val.cc | 7797 +++++++------- src/Val.h | 2428 ++--- src/Var.cc | 1401 ++- src/Var.h | 31 +- src/WeirdState.cc | 49 +- src/WeirdState.h | 19 +- src/ZVal.cc | 378 +- src/ZVal.h | 270 +- src/ZeekArgs.cc | 75 +- src/ZeekArgs.h | 8 +- src/ZeekList.h | 10 +- src/ZeekString.cc | 1225 +-- src/ZeekString.h | 243 +- src/analyzer/Analyzer.cc | 1678 ++- src/analyzer/Analyzer.h | 1386 ++- src/analyzer/Component.cc | 51 +- src/analyzer/Component.h | 163 +- src/analyzer/Manager.cc | 661 +- src/analyzer/Manager.h | 613 +- .../protocol/bittorrent/BitTorrent.cc | 188 +- src/analyzer/protocol/bittorrent/BitTorrent.h | 35 +- .../protocol/bittorrent/BitTorrentTracker.cc | 1355 ++- .../protocol/bittorrent/BitTorrentTracker.h | 187 +- src/analyzer/protocol/bittorrent/Plugin.cc | 33 +- src/analyzer/protocol/conn-size/ConnSize.cc | 292 +- src/analyzer/protocol/conn-size/ConnSize.h | 61 +- src/analyzer/protocol/conn-size/Plugin.cc | 27 +- src/analyzer/protocol/dce-rpc/DCE_RPC.cc | 79 +- src/analyzer/protocol/dce-rpc/DCE_RPC.h | 37 +- src/analyzer/protocol/dce-rpc/Plugin.cc | 26 +- src/analyzer/protocol/dhcp/DHCP.cc | 41 +- src/analyzer/protocol/dhcp/DHCP.h | 23 +- src/analyzer/protocol/dhcp/Plugin.cc | 26 +- src/analyzer/protocol/dnp3/DNP3.cc | 649 +- src/analyzer/protocol/dnp3/DNP3.h | 125 +- src/analyzer/protocol/dnp3/Plugin.cc | 29 +- src/analyzer/protocol/dns/DNS.cc | 4021 ++++--- src/analyzer/protocol/dns/DNS.h | 770 +- src/analyzer/protocol/dns/Plugin.cc | 28 +- src/analyzer/protocol/file/File.cc | 99 +- src/analyzer/protocol/file/File.h | 43 +- src/analyzer/protocol/file/Plugin.cc | 26 +- src/analyzer/protocol/finger/legacy/Finger.cc | 117 +- src/analyzer/protocol/finger/legacy/Finger.h | 28 +- src/analyzer/protocol/finger/legacy/Plugin.cc | 26 +- src/analyzer/protocol/ftp/FTP.cc | 547 +- src/analyzer/protocol/ftp/FTP.h | 52 +- src/analyzer/protocol/ftp/Plugin.cc | 28 +- src/analyzer/protocol/gnutella/Gnutella.cc | 429 +- src/analyzer/protocol/gnutella/Gnutella.h | 88 +- src/analyzer/protocol/gnutella/Plugin.cc | 27 +- src/analyzer/protocol/gssapi/GSSAPI.cc | 69 +- src/analyzer/protocol/gssapi/GSSAPI.h | 31 +- src/analyzer/protocol/gssapi/Plugin.cc | 26 +- src/analyzer/protocol/http/HTTP.cc | 3107 +++--- src/analyzer/protocol/http/HTTP.h | 386 +- src/analyzer/protocol/http/Plugin.cc | 26 +- src/analyzer/protocol/ident/Ident.cc | 344 +- src/analyzer/protocol/ident/Ident.h | 34 +- src/analyzer/protocol/ident/Plugin.cc | 26 +- src/analyzer/protocol/imap/IMAP.cc | 119 +- src/analyzer/protocol/imap/IMAP.h | 34 +- src/analyzer/protocol/imap/Plugin.cc | 26 +- src/analyzer/protocol/irc/IRC.cc | 2164 ++-- src/analyzer/protocol/irc/IRC.h | 131 +- src/analyzer/protocol/irc/Plugin.cc | 29 +- src/analyzer/protocol/krb/KRB.cc | 204 +- src/analyzer/protocol/krb/KRB.h | 37 +- src/analyzer/protocol/krb/KRB_TCP.cc | 85 +- src/analyzer/protocol/krb/KRB_TCP.h | 39 +- src/analyzer/protocol/krb/Plugin.cc | 29 +- src/analyzer/protocol/login/Login.cc | 1038 +- src/analyzer/protocol/login/Login.h | 107 +- src/analyzer/protocol/login/NVT.cc | 979 +- src/analyzer/protocol/login/NVT.h | 221 +- src/analyzer/protocol/login/Plugin.cc | 40 +- src/analyzer/protocol/login/RSH.cc | 327 +- src/analyzer/protocol/login/RSH.h | 64 +- src/analyzer/protocol/login/Rlogin.cc | 370 +- src/analyzer/protocol/login/Rlogin.h | 74 +- src/analyzer/protocol/login/Telnet.cc | 22 +- src/analyzer/protocol/login/Telnet.h | 16 +- src/analyzer/protocol/mime/MIME.cc | 2416 ++--- src/analyzer/protocol/mime/MIME.h | 354 +- src/analyzer/protocol/mime/Plugin.cc | 23 +- src/analyzer/protocol/modbus/Modbus.cc | 55 +- src/analyzer/protocol/modbus/Modbus.h | 29 +- src/analyzer/protocol/modbus/Plugin.cc | 26 +- src/analyzer/protocol/mqtt/MQTT.cc | 66 +- src/analyzer/protocol/mqtt/MQTT.h | 40 +- src/analyzer/protocol/mqtt/Plugin.cc | 26 +- src/analyzer/protocol/mysql/MySQL.cc | 120 +- src/analyzer/protocol/mysql/MySQL.h | 37 +- src/analyzer/protocol/mysql/Plugin.cc | 26 +- src/analyzer/protocol/ncp/NCP.cc | 361 +- src/analyzer/protocol/ncp/NCP.h | 113 +- src/analyzer/protocol/ncp/Plugin.cc | 28 +- src/analyzer/protocol/netbios/NetbiosSSN.cc | 726 +- src/analyzer/protocol/netbios/NetbiosSSN.h | 212 +- src/analyzer/protocol/netbios/Plugin.cc | 29 +- src/analyzer/protocol/ntlm/NTLM.cc | 68 +- src/analyzer/protocol/ntlm/NTLM.h | 31 +- src/analyzer/protocol/ntlm/Plugin.cc | 26 +- src/analyzer/protocol/ntp/NTP.cc | 45 +- src/analyzer/protocol/ntp/NTP.h | 25 +- src/analyzer/protocol/ntp/Plugin.cc | 26 +- src/analyzer/protocol/pia/PIA.cc | 636 +- src/analyzer/protocol/pia/PIA.h | 228 +- src/analyzer/protocol/pia/Plugin.cc | 29 +- src/analyzer/protocol/pop3/POP3.cc | 1589 ++- src/analyzer/protocol/pop3/POP3.h | 166 +- src/analyzer/protocol/pop3/Plugin.cc | 26 +- src/analyzer/protocol/quic/decrypt_crypto.cc | 468 +- src/analyzer/protocol/radius/Plugin.cc | 26 +- src/analyzer/protocol/radius/RADIUS.cc | 44 +- src/analyzer/protocol/radius/RADIUS.h | 25 +- src/analyzer/protocol/rdp/Plugin.cc | 30 +- src/analyzer/protocol/rdp/RDP.cc | 138 +- src/analyzer/protocol/rdp/RDP.h | 36 +- src/analyzer/protocol/rdp/RDPEUDP.cc | 43 +- src/analyzer/protocol/rdp/RDPEUDP.h | 27 +- src/analyzer/protocol/rfb/Plugin.cc | 26 +- src/analyzer/protocol/rfb/RFB.cc | 100 +- src/analyzer/protocol/rfb/RFB.h | 38 +- src/analyzer/protocol/rpc/MOUNT.cc | 427 +- src/analyzer/protocol/rpc/MOUNT.h | 65 +- src/analyzer/protocol/rpc/NFS.cc | 1393 ++- src/analyzer/protocol/rpc/NFS.h | 139 +- src/analyzer/protocol/rpc/Plugin.cc | 37 +- src/analyzer/protocol/rpc/Portmap.cc | 434 +- src/analyzer/protocol/rpc/Portmap.h | 49 +- src/analyzer/protocol/rpc/RPC.cc | 1255 +-- src/analyzer/protocol/rpc/RPC.h | 377 +- src/analyzer/protocol/rpc/XDR.cc | 153 +- src/analyzer/protocol/rpc/XDR.h | 5 +- src/analyzer/protocol/sip/Plugin.cc | 32 +- src/analyzer/protocol/sip/SIP.cc | 53 +- src/analyzer/protocol/sip/SIP.h | 25 +- src/analyzer/protocol/sip/SIP_TCP.cc | 85 +- src/analyzer/protocol/sip/SIP_TCP.h | 30 +- src/analyzer/protocol/smb/Plugin.cc | 28 +- src/analyzer/protocol/smb/SMB.cc | 115 +- src/analyzer/protocol/smb/SMB.h | 38 +- src/analyzer/protocol/smtp/Plugin.cc | 26 +- src/analyzer/protocol/smtp/SMTP.cc | 1314 ++- src/analyzer/protocol/smtp/SMTP.h | 123 +- src/analyzer/protocol/snmp/Plugin.cc | 26 +- src/analyzer/protocol/snmp/SNMP.cc | 45 +- src/analyzer/protocol/snmp/SNMP.h | 27 +- src/analyzer/protocol/socks/Plugin.cc | 26 +- src/analyzer/protocol/socks/SOCKS.cc | 130 +- src/analyzer/protocol/socks/SOCKS.h | 44 +- src/analyzer/protocol/ssh/Plugin.cc | 26 +- src/analyzer/protocol/ssh/SSH.cc | 266 +- src/analyzer/protocol/ssh/SSH.h | 49 +- src/analyzer/protocol/ssl/DTLS.cc | 123 +- src/analyzer/protocol/ssl/DTLS.h | 89 +- src/analyzer/protocol/ssl/Plugin.cc | 29 +- src/analyzer/protocol/ssl/SSL.cc | 605 +- src/analyzer/protocol/ssl/SSL.h | 292 +- src/analyzer/protocol/syslog/legacy/Plugin.cc | 26 +- src/analyzer/protocol/syslog/legacy/Syslog.cc | 45 +- src/analyzer/protocol/syslog/legacy/Syslog.h | 25 +- src/analyzer/protocol/tcp/ContentLine.cc | 558 +- src/analyzer/protocol/tcp/ContentLine.h | 146 +- src/analyzer/protocol/tcp/Plugin.cc | 30 +- src/analyzer/protocol/tcp/TCP.cc | 373 +- src/analyzer/protocol/tcp/TCP.h | 158 +- src/analyzer/protocol/tcp/TCP_Endpoint.cc | 477 +- src/analyzer/protocol/tcp/TCP_Endpoint.h | 368 +- src/analyzer/protocol/tcp/TCP_Flags.h | 67 +- src/analyzer/protocol/tcp/TCP_Reassembler.cc | 1101 +- src/analyzer/protocol/tcp/TCP_Reassembler.h | 161 +- src/analyzer/protocol/xmpp/Plugin.cc | 26 +- src/analyzer/protocol/xmpp/XMPP.cc | 116 +- src/analyzer/protocol/xmpp/XMPP.h | 34 +- src/analyzer/protocol/zip/Plugin.cc | 25 +- src/analyzer/protocol/zip/ZIP.cc | 162 +- src/analyzer/protocol/zip/ZIP.h | 37 +- src/binpac_zeek.h | 12 +- src/broker/Data.cc | 2378 ++--- src/broker/Data.h | 205 +- src/broker/Manager.cc | 3636 +++---- src/broker/Manager.h | 789 +- src/broker/Store.cc | 133 +- src/broker/Store.h | 135 +- src/digest.cc | 107 +- src/digest.h | 57 +- src/file_analysis/Analyzer.cc | 126 +- src/file_analysis/Analyzer.h | 333 +- src/file_analysis/AnalyzerSet.cc | 264 +- src/file_analysis/AnalyzerSet.h | 323 +- src/file_analysis/Component.cc | 44 +- src/file_analysis/Component.h | 132 +- src/file_analysis/File.cc | 1085 +- src/file_analysis/File.h | 586 +- src/file_analysis/FileReassembler.cc | 241 +- src/file_analysis/FileReassembler.h | 77 +- src/file_analysis/FileTimer.cc | 56 +- src/file_analysis/FileTimer.h | 40 +- src/file_analysis/Manager.cc | 777 +- src/file_analysis/Manager.h | 738 +- .../analyzer/data_event/DataEvent.cc | 72 +- .../analyzer/data_event/DataEvent.h | 84 +- .../analyzer/data_event/Plugin.cc | 27 +- src/file_analysis/analyzer/entropy/Entropy.cc | 88 +- src/file_analysis/analyzer/entropy/Entropy.h | 104 +- src/file_analysis/analyzer/entropy/Plugin.cc | 26 +- src/file_analysis/analyzer/extract/Extract.cc | 274 +- src/file_analysis/analyzer/extract/Extract.h | 108 +- src/file_analysis/analyzer/extract/Plugin.cc | 26 +- src/file_analysis/analyzer/hash/Hash.cc | 69 +- src/file_analysis/analyzer/hash/Hash.h | 220 +- src/file_analysis/analyzer/hash/Plugin.cc | 32 +- src/file_analysis/analyzer/pe/PE.cc | 56 +- src/file_analysis/analyzer/pe/PE.h | 31 +- src/file_analysis/analyzer/pe/Plugin.cc | 26 +- src/file_analysis/analyzer/x509/OCSP.cc | 878 +- src/file_analysis/analyzer/x509/OCSP.h | 38 +- src/file_analysis/analyzer/x509/Plugin.cc | 43 +- src/file_analysis/analyzer/x509/X509.cc | 879 +- src/file_analysis/analyzer/x509/X509.h | 250 +- src/file_analysis/analyzer/x509/X509Common.cc | 497 +- src/file_analysis/analyzer/x509/X509Common.h | 60 +- src/fuzzers/FuzzBuffer.cc | 116 +- src/fuzzers/FuzzBuffer.h | 92 +- src/fuzzers/dns-fuzzer.cc | 80 +- src/fuzzers/fuzzer-setup.h | 140 +- src/fuzzers/generic-analyzer-fuzzer.cc | 249 +- src/fuzzers/packet-fuzzer.cc | 54 +- src/fuzzers/standalone-driver.cc | 88 +- src/input.h | 7 +- src/input/Component.cc | 30 +- src/input/Component.h | 76 +- src/input/Manager.cc | 3996 ++++--- src/input/Manager.h | 414 +- src/input/ReaderBackend.cc | 453 +- src/input/ReaderBackend.h | 602 +- src/input/ReaderFrontend.cc | 123 +- src/input/ReaderFrontend.h | 210 +- src/input/readers/ascii/Ascii.cc | 686 +- src/input/readers/ascii/Ascii.h | 137 +- src/input/readers/ascii/Plugin.cc | 26 +- src/input/readers/benchmark/Benchmark.cc | 376 +- src/input/readers/benchmark/Benchmark.h | 53 +- src/input/readers/benchmark/Plugin.cc | 26 +- src/input/readers/binary/Binary.cc | 377 +- src/input/readers/binary/Binary.h | 49 +- src/input/readers/binary/Plugin.cc | 26 +- src/input/readers/config/Config.cc | 451 +- src/input/readers/config/Config.h | 57 +- src/input/readers/config/Plugin.cc | 26 +- src/input/readers/raw/Plugin.cc | 29 +- src/input/readers/raw/Plugin.h | 22 +- src/input/readers/raw/Raw.cc | 1182 +-- src/input/readers/raw/Raw.h | 111 +- src/input/readers/sqlite/Plugin.cc | 26 +- src/input/readers/sqlite/SQLite.cc | 485 +- src/input/readers/sqlite/SQLite.h | 52 +- src/iosource/BPF_Program.cc | 194 +- src/iosource/BPF_Program.h | 130 +- src/iosource/Component.cc | 182 +- src/iosource/Component.h | 235 +- src/iosource/IOSource.h | 172 +- src/iosource/Manager.cc | 742 +- src/iosource/Manager.h | 342 +- src/iosource/Packet.cc | 270 +- src/iosource/Packet.h | 426 +- src/iosource/PktDumper.cc | 81 +- src/iosource/PktDumper.h | 206 +- src/iosource/PktSrc.cc | 573 +- src/iosource/PktSrc.h | 643 +- src/iosource/pcap/Dumper.cc | 160 +- src/iosource/pcap/Dumper.h | 37 +- src/iosource/pcap/Plugin.cc | 31 +- src/iosource/pcap/Source.cc | 712 +- src/iosource/pcap/Source.h | 55 +- src/logging/Component.cc | 30 +- src/logging/Component.h | 76 +- src/logging/Manager.cc | 2461 ++--- src/logging/Manager.h | 490 +- src/logging/WriterBackend.cc | 474 +- src/logging/WriterBackend.h | 662 +- src/logging/WriterFrontend.cc | 391 +- src/logging/WriterFrontend.h | 343 +- src/logging/writers/ascii/Ascii.cc | 1591 ++- src/logging/writers/ascii/Ascii.h | 109 +- src/logging/writers/ascii/Plugin.cc | 28 +- src/logging/writers/none/None.cc | 68 +- src/logging/writers/none/None.h | 37 +- src/logging/writers/none/Plugin.cc | 26 +- src/logging/writers/sqlite/Plugin.cc | 26 +- src/logging/writers/sqlite/SQLite.cc | 511 +- src/logging/writers/sqlite/SQLite.h | 54 +- src/main.cc | 139 +- src/module_util.cc | 151 +- src/module_util.h | 5 +- src/net_util.cc | 284 +- src/net_util.h | 377 +- src/packet_analysis/Analyzer.cc | 319 +- src/packet_analysis/Analyzer.h | 464 +- src/packet_analysis/Component.cc | 51 +- src/packet_analysis/Component.h | 72 +- src/packet_analysis/Dispatcher.cc | 139 +- src/packet_analysis/Dispatcher.h | 78 +- src/packet_analysis/Manager.cc | 370 +- src/packet_analysis/Manager.h | 350 +- src/packet_analysis/protocol/arp/ARP.cc | 253 +- src/packet_analysis/protocol/arp/ARP.h | 31 +- src/packet_analysis/protocol/arp/Plugin.cc | 26 +- src/packet_analysis/protocol/ayiya/AYIYA.cc | 116 +- src/packet_analysis/protocol/ayiya/AYIYA.h | 23 +- src/packet_analysis/protocol/ayiya/Plugin.cc | 27 +- .../protocol/ethernet/Ethernet.cc | 114 +- .../protocol/ethernet/Ethernet.h | 27 +- .../protocol/ethernet/Plugin.cc | 28 +- src/packet_analysis/protocol/fddi/FDDI.cc | 22 +- src/packet_analysis/protocol/fddi/FDDI.h | 21 +- src/packet_analysis/protocol/fddi/Plugin.cc | 27 +- src/packet_analysis/protocol/geneve/Geneve.cc | 137 +- src/packet_analysis/protocol/geneve/Geneve.h | 21 +- src/packet_analysis/protocol/geneve/Plugin.cc | 27 +- src/packet_analysis/protocol/gre/GRE.cc | 347 +- src/packet_analysis/protocol/gre/GRE.h | 21 +- src/packet_analysis/protocol/gre/Plugin.cc | 26 +- src/packet_analysis/protocol/gtpv1/GTPv1.cc | 148 +- src/packet_analysis/protocol/gtpv1/GTPv1.h | 49 +- src/packet_analysis/protocol/gtpv1/Plugin.cc | 27 +- src/packet_analysis/protocol/icmp/ICMP.cc | 1485 ++- src/packet_analysis/protocol/icmp/ICMP.h | 94 +- .../protocol/icmp/ICMPSessionAdapter.cc | 113 +- .../protocol/icmp/ICMPSessionAdapter.h | 33 +- src/packet_analysis/protocol/icmp/Plugin.cc | 29 +- .../protocol/ieee802_11/IEEE802_11.cc | 275 +- .../protocol/ieee802_11/IEEE802_11.h | 23 +- .../protocol/ieee802_11/Plugin.cc | 28 +- .../ieee802_11_radio/IEEE802_11_Radio.cc | 34 +- .../ieee802_11_radio/IEEE802_11_Radio.h | 21 +- .../protocol/ieee802_11_radio/Plugin.cc | 28 +- src/packet_analysis/protocol/ip/IP.cc | 500 +- src/packet_analysis/protocol/ip/IP.h | 45 +- .../protocol/ip/IPBasedAnalyzer.cc | 413 +- .../protocol/ip/IPBasedAnalyzer.h | 313 +- src/packet_analysis/protocol/ip/Plugin.cc | 26 +- .../protocol/ip/SessionAdapter.cc | 42 +- .../protocol/ip/SessionAdapter.h | 150 +- .../protocol/iptunnel/IPTunnel.cc | 322 +- .../protocol/iptunnel/IPTunnel.h | 136 +- .../protocol/iptunnel/Plugin.cc | 28 +- .../protocol/linux_sll/LinuxSLL.cc | 36 +- .../protocol/linux_sll/LinuxSLL.h | 38 +- .../protocol/linux_sll/Plugin.cc | 28 +- .../protocol/linux_sll2/LinuxSLL2.cc | 36 +- .../protocol/linux_sll2/LinuxSLL2.h | 42 +- .../protocol/linux_sll2/Plugin.cc | 28 +- src/packet_analysis/protocol/llc/LLC.cc | 45 +- src/packet_analysis/protocol/llc/LLC.h | 21 +- src/packet_analysis/protocol/llc/Plugin.cc | 26 +- src/packet_analysis/protocol/mpls/MPLS.cc | 37 +- src/packet_analysis/protocol/mpls/MPLS.h | 21 +- src/packet_analysis/protocol/mpls/Plugin.cc | 27 +- src/packet_analysis/protocol/nflog/NFLog.cc | 117 +- src/packet_analysis/protocol/nflog/NFLog.h | 18 +- src/packet_analysis/protocol/nflog/Plugin.cc | 27 +- .../protocol/novell_802_3/Novell_802_3.cc | 13 +- .../protocol/novell_802_3/Novell_802_3.h | 21 +- .../protocol/novell_802_3/Plugin.cc | 28 +- src/packet_analysis/protocol/null/Null.cc | 22 +- src/packet_analysis/protocol/null/Null.h | 21 +- src/packet_analysis/protocol/null/Plugin.cc | 27 +- src/packet_analysis/protocol/pbb/PBB.cc | 20 +- src/packet_analysis/protocol/pbb/PBB.h | 21 +- src/packet_analysis/protocol/pbb/Plugin.cc | 26 +- src/packet_analysis/protocol/ppp/PPP.cc | 56 +- src/packet_analysis/protocol/ppp/PPP.h | 21 +- src/packet_analysis/protocol/ppp/Plugin.cc | 26 +- .../protocol/ppp_serial/PPPSerial.cc | 24 +- .../protocol/ppp_serial/PPPSerial.h | 21 +- .../protocol/ppp_serial/Plugin.cc | 28 +- src/packet_analysis/protocol/pppoe/PPPoE.cc | 24 +- src/packet_analysis/protocol/pppoe/PPPoE.h | 21 +- src/packet_analysis/protocol/pppoe/Plugin.cc | 27 +- src/packet_analysis/protocol/root/Plugin.cc | 27 +- src/packet_analysis/protocol/root/Root.cc | 9 +- src/packet_analysis/protocol/root/Root.h | 21 +- src/packet_analysis/protocol/skip/Plugin.cc | 27 +- src/packet_analysis/protocol/skip/Skip.cc | 24 +- src/packet_analysis/protocol/skip/Skip.h | 25 +- src/packet_analysis/protocol/snap/Plugin.cc | 27 +- src/packet_analysis/protocol/snap/SNAP.cc | 68 +- src/packet_analysis/protocol/snap/SNAP.h | 21 +- src/packet_analysis/protocol/tcp/Plugin.cc | 28 +- src/packet_analysis/protocol/tcp/Stats.cc | 123 +- src/packet_analysis/protocol/tcp/Stats.h | 105 +- src/packet_analysis/protocol/tcp/TCP.cc | 247 +- src/packet_analysis/protocol/tcp/TCP.h | 116 +- .../protocol/tcp/TCPSessionAdapter.cc | 3261 +++--- .../protocol/tcp/TCPSessionAdapter.h | 250 +- src/packet_analysis/protocol/teredo/Plugin.cc | 27 +- src/packet_analysis/protocol/teredo/Teredo.cc | 485 +- src/packet_analysis/protocol/teredo/Teredo.h | 129 +- src/packet_analysis/protocol/udp/Plugin.cc | 28 +- src/packet_analysis/protocol/udp/UDP.cc | 320 +- src/packet_analysis/protocol/udp/UDP.h | 78 +- .../protocol/udp/UDPSessionAdapter.cc | 147 +- .../protocol/udp/UDPSessionAdapter.h | 41 +- src/packet_analysis/protocol/vlan/Plugin.cc | 27 +- src/packet_analysis/protocol/vlan/VLAN.cc | 88 +- src/packet_analysis/protocol/vlan/VLAN.h | 27 +- src/packet_analysis/protocol/vntag/Plugin.cc | 27 +- src/packet_analysis/protocol/vntag/VNTag.cc | 22 +- src/packet_analysis/protocol/vntag/VNTag.h | 21 +- src/packet_analysis/protocol/vxlan/Plugin.cc | 27 +- src/packet_analysis/protocol/vxlan/VXLAN.cc | 100 +- src/packet_analysis/protocol/vxlan/VXLAN.h | 21 +- .../protocol/wrapper/Plugin.cc | 28 +- .../protocol/wrapper/Wrapper.cc | 231 +- .../protocol/wrapper/Wrapper.h | 21 +- src/plugin/Component.cc | 109 +- src/plugin/Component.h | 215 +- src/plugin/ComponentManager.h | 472 +- src/plugin/Manager.cc | 1786 ++-- src/plugin/Manager.h | 886 +- src/plugin/Plugin.cc | 877 +- src/plugin/Plugin.h | 1860 ++-- src/probabilistic/BitVector.cc | 1028 +- src/probabilistic/BitVector.h | 558 +- src/probabilistic/BloomFilter.cc | 609 +- src/probabilistic/BloomFilter.h | 414 +- src/probabilistic/CardinalityCounter.cc | 326 +- src/probabilistic/CardinalityCounter.h | 307 +- src/probabilistic/CounterVector.cc | 266 +- src/probabilistic/CounterVector.h | 251 +- src/probabilistic/Hasher.cc | 218 +- src/probabilistic/Hasher.h | 356 +- src/probabilistic/Topk.cc | 758 +- src/probabilistic/Topk.h | 280 +- src/script_opt/CPP/Attrs.cc | 211 +- src/script_opt/CPP/Attrs.h | 20 +- src/script_opt/CPP/Compile.h | 1941 ++-- src/script_opt/CPP/Consts.cc | 198 +- src/script_opt/CPP/DeclFunc.cc | 738 +- src/script_opt/CPP/Driver.cc | 759 +- src/script_opt/CPP/Emit.cc | 32 +- src/script_opt/CPP/Exprs.cc | 2431 ++--- src/script_opt/CPP/Func.cc | 99 +- src/script_opt/CPP/Func.h | 112 +- src/script_opt/CPP/GenFunc.cc | 395 +- src/script_opt/CPP/Inits.cc | 484 +- src/script_opt/CPP/InitsInfo.cc | 933 +- src/script_opt/CPP/InitsInfo.h | 783 +- src/script_opt/CPP/Runtime.h | 5 +- src/script_opt/CPP/RuntimeInitSupport.cc | 360 +- src/script_opt/CPP/RuntimeInitSupport.h | 36 +- src/script_opt/CPP/RuntimeInits.cc | 903 +- src/script_opt/CPP/RuntimeInits.h | 724 +- src/script_opt/CPP/RuntimeOps.cc | 458 +- src/script_opt/CPP/RuntimeOps.h | 236 +- src/script_opt/CPP/RuntimeVec.cc | 607 +- src/script_opt/CPP/RuntimeVec.h | 25 +- src/script_opt/CPP/Stmts.cc | 1063 +- src/script_opt/CPP/Tracker.cc | 68 +- src/script_opt/CPP/Tracker.h | 89 +- src/script_opt/CPP/Types.cc | 640 +- src/script_opt/CPP/Util.cc | 174 +- src/script_opt/CPP/Util.h | 20 +- src/script_opt/CPP/Vars.cc | 220 +- src/script_opt/Expr.cc | 5591 +++++----- src/script_opt/ExprOptInfo.h | 16 +- src/script_opt/GenIDDefs.cc | 889 +- src/script_opt/GenIDDefs.h | 166 +- src/script_opt/IDOptInfo.cc | 964 +- src/script_opt/IDOptInfo.h | 415 +- src/script_opt/Inline.cc | 392 +- src/script_opt/Inline.h | 96 +- src/script_opt/ProfileFunc.cc | 1715 ++- src/script_opt/ProfileFunc.h | 551 +- src/script_opt/Reduce.cc | 2069 ++-- src/script_opt/Reduce.h | 511 +- src/script_opt/ScriptOpt.cc | 1115 +- src/script_opt/ScriptOpt.h | 200 +- src/script_opt/Stmt.cc | 2122 ++-- src/script_opt/StmtOptInfo.h | 40 +- src/script_opt/TempVar.cc | 33 +- src/script_opt/TempVar.h | 59 +- src/script_opt/UsageAnalyzer.cc | 402 +- src/script_opt/UsageAnalyzer.h | 100 +- src/script_opt/UseDefs.cc | 1334 ++- src/script_opt/UseDefs.h | 233 +- src/script_opt/ZAM/AM-Opt.cc | 1925 ++-- src/script_opt/ZAM/Branches.cc | 241 +- src/script_opt/ZAM/BuiltIn.cc | 1070 +- src/script_opt/ZAM/BuiltInSupport.cc | 232 +- src/script_opt/ZAM/BuiltInSupport.h | 141 +- src/script_opt/ZAM/Compile.h | 1029 +- src/script_opt/ZAM/Driver.cc | 838 +- src/script_opt/ZAM/Expr.cc | 2265 ++-- src/script_opt/ZAM/Inst-Gen.cc | 228 +- src/script_opt/ZAM/Inst-Gen.h | 9 +- src/script_opt/ZAM/IterInfo.h | 190 +- src/script_opt/ZAM/Low-Level.cc | 272 +- src/script_opt/ZAM/Stmt.cc | 2103 ++-- src/script_opt/ZAM/Support.cc | 143 +- src/script_opt/ZAM/Support.h | 15 +- src/script_opt/ZAM/Vars.cc | 245 +- src/script_opt/ZAM/ZBody.cc | 691 +- src/script_opt/ZAM/ZBody.h | 144 +- src/script_opt/ZAM/ZInst.cc | 1218 +-- src/script_opt/ZAM/ZInst.h | 674 +- src/script_opt/ZAM/ZOp.cc | 177 +- src/script_opt/ZAM/ZOp.h | 76 +- src/session/Key.cc | 115 +- src/session/Key.h | 93 +- src/session/Manager.cc | 374 +- src/session/Manager.h | 122 +- src/session/Session.cc | 306 +- src/session/Session.h | 401 +- src/supervisor/Supervisor.cc | 3169 +++--- src/supervisor/Supervisor.h | 796 +- src/telemetry/Counter.h | 261 +- src/telemetry/Gauge.h | 295 +- src/telemetry/Histogram.h | 265 +- src/telemetry/Manager.cc | 1030 +- src/telemetry/Manager.h | 748 +- src/telemetry/MetricFamily.h | 89 +- src/telemetry/Timer.h | 45 +- src/threading/BasicThread.cc | 265 +- src/threading/BasicThread.h | 332 +- src/threading/Formatter.cc | 169 +- src/threading/Formatter.h | 267 +- src/threading/Manager.cc | 383 +- src/threading/Manager.h | 220 +- src/threading/MsgThread.cc | 680 +- src/threading/MsgThread.h | 741 +- src/threading/Queue.h | 347 +- src/threading/SerialTypes.cc | 1200 +-- src/threading/SerialTypes.h | 413 +- src/threading/formatters/Ascii.cc | 924 +- src/threading/formatters/Ascii.h | 92 +- src/threading/formatters/JSON.cc | 258 +- src/threading/formatters/JSON.h | 65 +- src/threading/formatters/detail/json.h | 26 +- src/util.cc | 4139 ++++---- src/util.h | 252 +- src/zeek-affinity.cc | 54 +- src/zeek-affinity.h | 5 +- src/zeek-setup.cc | 1495 ++- src/zeek-setup.h | 14 +- src/zeekygen/Configuration.cc | 133 +- src/zeekygen/Configuration.h | 67 +- src/zeekygen/IdentifierInfo.cc | 215 +- src/zeekygen/IdentifierInfo.h | 293 +- src/zeekygen/Info.h | 80 +- src/zeekygen/Manager.cc | 652 +- src/zeekygen/Manager.h | 421 +- src/zeekygen/PackageInfo.cc | 61 +- src/zeekygen/PackageInfo.h | 45 +- src/zeekygen/ReStructuredTextTable.cc | 89 +- src/zeekygen/ReStructuredTextTable.h | 60 +- src/zeekygen/ScriptInfo.cc | 675 +- src/zeekygen/ScriptInfo.h | 148 +- src/zeekygen/SpicyModuleInfo.h | 73 +- src/zeekygen/Target.cc | 1092 +- src/zeekygen/Target.h | 467 +- src/zeekygen/utils.cc | 318 +- src/zeekygen/utils.h | 10 +- .../src/FOO.cc | 79 +- .../src/FOO.h | 35 +- .../src/Plugin.cc | 30 +- .../plugins/conflict-plugin/src/Plugin.cc | 24 +- .../plugins/conflict-plugin/src/Plugin.h | 14 +- .../plugins/doctest-plugin/src/Plugin.cc | 29 +- .../btest/plugins/doctest-plugin/src/Plugin.h | 14 +- testing/btest/plugins/file-plugin/src/Foo.cc | 38 +- testing/btest/plugins/file-plugin/src/Foo.h | 17 +- .../btest/plugins/file-plugin/src/Plugin.cc | 27 +- .../btest/plugins/file-plugin/src/Plugin.h | 14 +- .../plugins/func-hook-plugin/src/Plugin.cc | 115 +- .../plugins/func-hook-plugin/src/Plugin.h | 25 +- .../btest/plugins/hooks-plugin/src/Plugin.cc | 507 +- .../btest/plugins/hooks-plugin/src/Plugin.h | 61 +- .../plugins/iosource-plugin/src/Plugin.cc | 97 +- .../plugins/iosource-plugin/src/Plugin.h | 94 +- .../logging-hooks-plugin/src/Plugin.cc | 89 +- .../plugins/logging-hooks-plugin/src/Plugin.h | 29 +- .../packet-protocol-plugin/src/Plugin.cc | 33 +- .../packet-protocol-plugin/src/RawLayer.cc | 27 +- .../packet-protocol-plugin/src/RawLayer.h | 18 +- .../btest/plugins/pktdumper-plugin/src/Foo.cc | 37 +- .../btest/plugins/pktdumper-plugin/src/Foo.h | 24 +- .../plugins/pktdumper-plugin/src/Plugin.cc | 28 +- .../plugins/pktdumper-plugin/src/Plugin.h | 14 +- .../btest/plugins/pktsrc-plugin/src/Foo.cc | 96 +- testing/btest/plugins/pktsrc-plugin/src/Foo.h | 32 +- .../btest/plugins/pktsrc-plugin/src/Plugin.cc | 29 +- .../btest/plugins/pktsrc-plugin/src/Plugin.h | 14 +- .../plugin-load-dependency/1/src/Plugin.cc | 28 +- .../plugin-load-dependency/1/src/Plugin.h | 14 +- .../plugin-load-dependency/2/src/Plugin.cc | 29 +- .../plugin-load-dependency/2/src/Plugin.h | 14 +- .../plugin-load-dependency/3/src/Plugin.cc | 28 +- .../plugin-load-dependency/3/src/Plugin.h | 14 +- .../plugin-load-file-extended/src/Plugin.cc | 100 +- .../plugin-load-file-extended/src/Plugin.h | 20 +- .../src/Plugin.cc | 22 +- .../plugin-nopatchversion-plugin/src/Plugin.h | 14 +- .../src/Plugin.cc | 24 +- .../src/Plugin.h | 14 +- .../btest/plugins/protocol-plugin/src/Foo.cc | 65 +- .../btest/plugins/protocol-plugin/src/Foo.h | 36 +- .../plugins/protocol-plugin/src/Plugin.cc | 40 +- .../plugins/protocol-plugin/src/Plugin.h | 16 +- .../btest/plugins/reader-plugin/src/Foo.cc | 241 +- testing/btest/plugins/reader-plugin/src/Foo.h | 35 +- .../btest/plugins/reader-plugin/src/Plugin.cc | 26 +- .../btest/plugins/reader-plugin/src/Plugin.h | 14 +- .../reporter-hook-plugin/src/Plugin.cc | 55 +- .../plugins/reporter-hook-plugin/src/Plugin.h | 22 +- .../src/Plugin.cc | 48 +- .../src/Plugin.h | 16 +- .../btest/plugins/writer-plugin/src/Foo.cc | 38 +- testing/btest/plugins/writer-plugin/src/Foo.h | 47 +- .../btest/plugins/writer-plugin/src/Plugin.cc | 26 +- .../btest/plugins/writer-plugin/src/Plugin.h | 14 +- .../Files/protocol-plugin/src/Foo.cc | 65 +- .../Files/protocol-plugin/src/Foo.h | 36 +- .../Files/protocol-plugin/src/Plugin.cc | 40 +- .../Files/protocol-plugin/src/Plugin.h | 16 +- .../Files/py-lib-plugin/plugin/src/Plugin.cc | 39 +- .../Files/py-lib-plugin/plugin/src/Plugin.h | 21 +- .../Files/zeek-version-plugin/src/Plugin.cc | 31 +- .../Files/zeek-version-plugin/src/Plugin.h | 16 +- 786 files changed, 131714 insertions(+), 153609 deletions(-) diff --git a/.clang-format b/.clang-format index 4c628b3465..d32dc20b13 100644 --- a/.clang-format +++ b/.clang-format @@ -1,74 +1,66 @@ -# Clang-format configuration for Zeek. This configuration requires -# at least clang-format 12.0.1 to format correctly. - -Language: Cpp -Standard: c++17 - -BreakBeforeBraces: Whitesmiths - -# BraceWrapping: -# AfterCaseLabel: true -# AfterClass: false -# AfterControlStatement: Always -# AfterEnum: false -# AfterFunction: true -# AfterNamespace: false -# AfterStruct: false -# AfterUnion: false -# AfterExternBlock: false -# BeforeCatch: true -# BeforeElse: true -# BeforeWhile: false -# IndentBraces: true -# SplitEmptyFunction: false -# SplitEmptyRecord: false -# SplitEmptyNamespace: false +# Copyright (c) 2020-2023 by the Zeek Project. See LICENSE for details. +--- +Language: Cpp AccessModifierOffset: -4 AlignAfterOpenBracket: Align -AlignTrailingComments: false -AllowShortBlocksOnASingleLine: Empty -AllowShortEnumsOnASingleLine: true -AllowShortFunctionsOnASingleLine: Inline +AlignConsecutiveAssignments: false +AlignConsecutiveDeclarations: false +AlignEscapedNewlines: Right +AlignOperands: true +AlignTrailingComments: true +AllowAllParametersOfDeclarationOnNextLine: false +AllowShortBlocksOnASingleLine: false +AllowShortCaseLabelsOnASingleLine: true +AllowShortFunctionsOnASingleLine: true AllowShortIfStatementsOnASingleLine: false -AllowShortLambdasOnASingleLine: Empty AllowShortLoopsOnASingleLine: false +AlwaysBreakAfterDefinitionReturnType: None AlwaysBreakAfterReturnType: None +AlwaysBreakBeforeMultilineStrings: true +AlwaysBreakTemplateDeclarations: Yes BinPackArguments: true BinPackParameters: true -BreakConstructorInitializers: BeforeColon +BraceWrapping: + AfterClass: false + AfterControlStatement: false + AfterEnum: false + AfterFunction: false + AfterNamespace: false + AfterObjCDeclaration: false + AfterStruct: false + AfterUnion: false + AfterExternBlock: false + BeforeCatch: false + BeforeElse: true + IndentBraces: false + SplitEmptyFunction: false + SplitEmptyRecord: false + SplitEmptyNamespace: false +BreakBeforeBinaryOperators: None +BreakBeforeBraces: Custom +BreakBeforeInheritanceComma: false BreakInheritanceList: BeforeColon -ColumnLimit: 100 -ConstructorInitializerAllOnOneLineOrOnePerLine: false -FixNamespaceComments: false -IndentCaseLabels: true -IndentCaseBlocks: false -IndentExternBlock: NoIndent -IndentPPDirectives: None -IndentWidth: 4 -NamespaceIndentation: None -PointerAlignment: Left -SpaceAfterCStyleCast: false -SpaceAfterLogicalNot: true -SpaceBeforeAssignmentOperators: true -SpaceBeforeCpp11BracedList: false -SpaceBeforeCtorInitializerColon: true -SpaceBeforeInheritanceColon: true -SpaceBeforeParens: ControlStatements -SpaceBeforeRangeBasedForLoopColon: true -SpaceInEmptyBlock: true -SpaceInEmptyParentheses: false -SpacesInAngles: false -SpacesInConditionalStatement: true -SpacesInContainerLiterals: false -SpacesInParentheses: false -TabWidth: 4 -UseTab: AlignWithSpaces - -# Setting this to a high number causes clang-format to prefer breaking somewhere else -# over breaking after the assignment operator in a line that's over the column limit -PenaltyBreakAssignment: 100 - +BreakBeforeTernaryOperators: false +BreakConstructorInitializersBeforeComma: false +BreakConstructorInitializers: BeforeColon +BreakAfterJavaFieldAnnotations: false +BreakStringLiterals: true +ColumnLimit: 120 +CommentPragmas: 'NOLINT' +CompactNamespaces: false +ConstructorInitializerAllOnOneLineOrOnePerLine: true +ConstructorInitializerIndentWidth: 4 +ContinuationIndentWidth: 4 +Cpp11BracedListStyle: true +DerivePointerAlignment: false +DisableFormat: false +ExperimentalAutoDetectBinPacking: false +FixNamespaceComments: true +ForEachMacros: + - foreach + - Q_FOREACH + - BOOST_FOREACH IncludeBlocks: Regroup # Include categories go like this: @@ -98,3 +90,55 @@ IncludeCategories: Priority: 4 - Regex: '.*' Priority: 5 + +IncludeIsMainRegex: '$' +IndentCaseLabels: true +IndentPPDirectives: None +IndentWidth: 4 +IndentWrappedFunctionNames: false +JavaScriptQuotes: Leave +JavaScriptWrapImports: true +KeepEmptyLinesAtTheStartOfBlocks: false +MacroBlockBegin: '^BEGIN_' +MacroBlockEnd: '^END_' +MaxEmptyLinesToKeep: 2 +NamespaceIndentation: None +ObjCBinPackProtocolList: Auto +ObjCBlockIndentWidth: 2 +ObjCSpaceAfterProperty: false +ObjCSpaceBeforeProtocolList: true +PenaltyBreakAssignment: 2 +PenaltyBreakBeforeFirstCallParameter: 500 +PenaltyBreakComment: 300 +PenaltyBreakFirstLessLess: 120 +PenaltyBreakString: 1000 +PenaltyBreakTemplateDeclaration: 10 +PenaltyExcessCharacter: 1000000 +PenaltyReturnTypeOnItsOwnLine: 1000 +PointerAlignment: Left +ReflowComments: true +SortIncludes: true +SortUsingDeclarations: true +SpaceAfterCStyleCast: false +SpaceAfterTemplateKeyword: false +SpaceAfterLogicalNot: true +SpaceBeforeAssignmentOperators: true +SpaceBeforeCpp11BracedList: false +SpaceBeforeCtorInitializerColon: true +SpaceBeforeInheritanceColon: true +SpaceBeforeParens: ControlStatements +SpaceBeforeRangeBasedForLoopColon: true +SpaceInEmptyParentheses: false +SpacesBeforeTrailingComments: 1 +SpacesInAngles: false +SpacesInContainerLiterals: true +SpacesInCStyleCastParentheses: false +SpacesInParentheses: false +SpacesInSquareBrackets: false +SpacesInConditionalStatement: true +Standard: Cpp11 +StatementMacros: + - STANDARD_OPERATOR_1 +TabWidth: 4 +UseTab: Never +... diff --git a/auxil/spicy b/auxil/spicy index 7b8eff527f..d26c81c0a2 160000 --- a/auxil/spicy +++ b/auxil/spicy @@ -1 +1 @@ -Subproject commit 7b8eff527f60ec58eff3242253bdc1f5f1fccbef +Subproject commit d26c81c0a2982ef81339beebff455c23713fb526 diff --git a/cmake b/cmake index 98799bb51a..f7b4fbe489 160000 --- a/cmake +++ b/cmake @@ -1 +1 @@ -Subproject commit 98799bb51aabb282e7dd6372aea7dbcf909469ac +Subproject commit f7b4fbe4892594034d3d9ca639c0ffa6a99fcbe5 diff --git a/doc b/doc index 22fe25d980..01d78f885e 160000 --- a/doc +++ b/doc @@ -1 +1 @@ -Subproject commit 22fe25d980131abdfadb4bdb9390aee347e77023 +Subproject commit 01d78f885e6aac4e853a0b5da559b4c849fee743 diff --git a/src/Anon.cc b/src/Anon.cc index 08f8f370fb..82fbd51f15 100644 --- a/src/Anon.cc +++ b/src/Anon.cc @@ -15,435 +15,392 @@ #include "zeek/net_util.h" #include "zeek/util.h" -namespace zeek::detail - { +namespace zeek::detail { AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS] = {nullptr}; -static uint32_t rand32() - { - return ((util::detail::random_number() & 0xffff) << 16) | - (util::detail::random_number() & 0xffff); - } +static uint32_t rand32() { + return ((util::detail::random_number() & 0xffff) << 16) | (util::detail::random_number() & 0xffff); +} // From tcpdpriv. -static int bi_ffs(uint32_t value) - { - int add = 0; - static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}; +static int bi_ffs(uint32_t value) { + int add = 0; + static uint8_t bvals[] = {0, 4, 3, 3, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1}; - if ( (value & 0xFFFF0000) == 0 ) - { - if ( value == 0 ) - // Zero input ==> zero output. - return 0; + if ( (value & 0xFFFF0000) == 0 ) { + if ( value == 0 ) + // Zero input ==> zero output. + return 0; - add += 16; - } + add += 16; + } - else - value >>= 16; + else + value >>= 16; - if ( (value & 0xFF00) == 0 ) - add += 8; - else - value >>= 8; + if ( (value & 0xFF00) == 0 ) + add += 8; + else + value >>= 8; - if ( (value & 0xF0) == 0 ) - add += 4; - else - value >>= 4; + if ( (value & 0xF0) == 0 ) + add += 4; + else + value >>= 4; - return add + bvals[value & 0xf]; - } + return add + bvals[value & 0xf]; +} #define first_n_bit_mask(n) (~(0xFFFFFFFFU >> n)) -ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) - { - std::map::iterator p = mapping.find(addr); - if ( p != mapping.end() ) - return p->second; - else - { - ipaddr32_t new_addr = anonymize(addr); - mapping[addr] = new_addr; +ipaddr32_t AnonymizeIPAddr::Anonymize(ipaddr32_t addr) { + std::map::iterator p = mapping.find(addr); + if ( p != mapping.end() ) + return p->second; + else { + ipaddr32_t new_addr = anonymize(addr); + mapping[addr] = new_addr; - return new_addr; - } - } + return new_addr; + } +} // Keep the specified prefix unchanged. -bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) - { - reporter->InternalError("prefix preserving is not supported for the anonymizer"); - return false; - } +bool AnonymizeIPAddr::PreservePrefix(ipaddr32_t /* input */, int /* num_bits */) { + reporter->InternalError("prefix preserving is not supported for the anonymizer"); + return false; +} -bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) - { - switch ( addr_to_class(ntohl(input)) ) - { - case 'A': - return PreservePrefix(input, 8); - case 'B': - return PreservePrefix(input, 16); - case 'C': - return PreservePrefix(input, 24); - default: - return false; - } - } +bool AnonymizeIPAddr::PreserveNet(ipaddr32_t input) { + switch ( addr_to_class(ntohl(input)) ) { + case 'A': return PreservePrefix(input, 8); + case 'B': return PreservePrefix(input, 16); + case 'C': return PreservePrefix(input, 24); + default: return false; + } +} -ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) - { - ++seq; - return htonl(seq); - } +ipaddr32_t AnonymizeIPAddr_Seq::anonymize(ipaddr32_t /* input */) { + ++seq; + return htonl(seq); +} -ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) - { - uint8_t digest[16]; - ipaddr32_t output = 0; +ipaddr32_t AnonymizeIPAddr_RandomMD5::anonymize(ipaddr32_t input) { + uint8_t digest[16]; + ipaddr32_t output = 0; - util::detail::hmac_md5(sizeof(input), (u_char*)(&input), digest); + util::detail::hmac_md5(sizeof(input), (u_char*)(&input), digest); - for ( int i = 0; i < 4; ++i ) - output = (output << 8) | digest[i]; + for ( int i = 0; i < 4; ++i ) + output = (output << 8) | digest[i]; - return output; - } + return output; +} // This code is from "On the Design and Performance of Prefix-Preserving // IP Traffic Trace Anonymization", by Xu et al (IMW 2001) // // http://www.imconf.net/imw-2001/proceedings.html -ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) - { - uint8_t digest[16]; - ipaddr32_t prefix_mask = 0xffffffff; - input = ntohl(input); - ipaddr32_t output = input; +ipaddr32_t AnonymizeIPAddr_PrefixMD5::anonymize(ipaddr32_t input) { + uint8_t digest[16]; + ipaddr32_t prefix_mask = 0xffffffff; + input = ntohl(input); + ipaddr32_t output = input; - for ( int i = 0; i < 32; ++i ) - { - // PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 . - prefix.len = htonl(i + 1); - prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i))); + for ( int i = 0; i < 32; ++i ) { + // PAD(x_0 ... x_{i-1}) = x_0 ... x_{i-1} 1 0 ... 0 . + prefix.len = htonl(i + 1); + prefix.prefix = htonl((input & ~(prefix_mask >> i)) | (1 << (31 - i))); - // HK(PAD(x_0 ... x_{i-1})). - util::detail::hmac_md5(sizeof(prefix), (u_char*)&prefix, digest); + // HK(PAD(x_0 ... x_{i-1})). + util::detail::hmac_md5(sizeof(prefix), (u_char*)&prefix, digest); - // f_{i-1} = LSB(HK(PAD(x_0 ... x_{i-1}))). - ipaddr32_t bit_mask = (digest[0] & 1) << (31 - i); + // f_{i-1} = LSB(HK(PAD(x_0 ... x_{i-1}))). + ipaddr32_t bit_mask = (digest[0] & 1) << (31 - i); - // x_i' = x_i ^ f_{i-1}. - output ^= bit_mask; - } + // x_i' = x_i ^ f_{i-1}. + output ^= bit_mask; + } - return htonl(output); - } + return htonl(output); +} -AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() - { - for ( auto& b : blocks ) - delete[] b; - } +AnonymizeIPAddr_A50::~AnonymizeIPAddr_A50() { + for ( auto& b : blocks ) + delete[] b; +} -void AnonymizeIPAddr_A50::init() - { - root = next_free_node = nullptr; +void AnonymizeIPAddr_A50::init() { + root = next_free_node = nullptr; - // Prepare special nodes for 0.0.0.0 and 255.255.255.255. - memset(&special_nodes[0], 0, sizeof(special_nodes)); - special_nodes[0].input = special_nodes[0].output = 0; - special_nodes[1].input = special_nodes[1].output = 0xFFFFFFFF; + // Prepare special nodes for 0.0.0.0 and 255.255.255.255. + memset(&special_nodes[0], 0, sizeof(special_nodes)); + special_nodes[0].input = special_nodes[0].output = 0; + special_nodes[1].input = special_nodes[1].output = 0xFFFFFFFF; - method = 0; - before_anonymization = 1; - new_mapping = 0; - } + method = 0; + before_anonymization = 1; + new_mapping = 0; +} -bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) - { - DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits); +bool AnonymizeIPAddr_A50::PreservePrefix(ipaddr32_t input, int num_bits) { + DEBUG_MSG("%s/%d\n", IPAddr(IPv4, &input, IPAddr::Network).AsString().c_str(), num_bits); - if ( ! before_anonymization ) - { - reporter->Error("prefix preservation specified after anonymization begun"); - return false; - } + if ( ! before_anonymization ) { + reporter->Error("prefix preservation specified after anonymization begun"); + return false; + } - input = ntohl(input); + input = ntohl(input); - // Sanitize input. - input = input & first_n_bit_mask(num_bits); + // Sanitize input. + input = input & first_n_bit_mask(num_bits); - Node* n = find_node(input); + Node* n = find_node(input); - // Preserve the first num_bits bits of addr. - if ( num_bits == 32 ) - n->output = input; + // Preserve the first num_bits bits of addr. + if ( num_bits == 32 ) + n->output = input; - else if ( num_bits > 0 ) - { - assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU); - uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits); - uint32_t prefix_mask = ~suffix_mask; - n->output = (input & prefix_mask) | (rand32() & suffix_mask); - } + else if ( num_bits > 0 ) { + assert((0xFFFFFFFFU >> 1) == 0x7FFFFFFFU); + uint32_t suffix_mask = (0xFFFFFFFFU >> num_bits); + uint32_t prefix_mask = ~suffix_mask; + n->output = (input & prefix_mask) | (rand32() & suffix_mask); + } - return true; - } + return true; +} -ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) - { - before_anonymization = 0; - new_mapping = 0; +ipaddr32_t AnonymizeIPAddr_A50::anonymize(ipaddr32_t a) { + before_anonymization = 0; + new_mapping = 0; - if ( Node* n = find_node(ntohl(a)) ) - { - ipaddr32_t output = htonl(n->output); - return output; - } - else - return 0; - } + if ( Node* n = find_node(ntohl(a)) ) { + ipaddr32_t output = htonl(n->output); + return output; + } + else + return 0; +} -AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() - { - assert(! next_free_node); +AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node_block() { + assert(! next_free_node); - int block_size = 1024; - Node* block = new Node[block_size]; - if ( ! block ) - reporter->InternalError("out of memory!"); + int block_size = 1024; + Node* block = new Node[block_size]; + if ( ! block ) + reporter->InternalError("out of memory!"); - blocks.push_back(block); + blocks.push_back(block); - for ( int i = 1; i < block_size - 1; ++i ) - block[i].child[0] = &block[i + 1]; + for ( int i = 1; i < block_size - 1; ++i ) + block[i].child[0] = &block[i + 1]; - block[block_size - 1].child[0] = nullptr; - next_free_node = &block[1]; + block[block_size - 1].child[0] = nullptr; + next_free_node = &block[1]; - return &block[0]; - } + return &block[0]; +} -inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() - { - new_mapping = 1; +inline AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::new_node() { + new_mapping = 1; - if ( next_free_node ) - { - Node* n = next_free_node; - next_free_node = n->child[0]; - return n; - } - else - return new_node_block(); - } + if ( next_free_node ) { + Node* n = next_free_node; + next_free_node = n->child[0]; + return n; + } + else + return new_node_block(); +} -inline void AnonymizeIPAddr_A50::free_node(Node* n) - { - n->child[0] = next_free_node; - next_free_node = n; - } +inline void AnonymizeIPAddr_A50::free_node(Node* n) { + n->child[0] = next_free_node; + next_free_node = n; +} -ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const - { - // -A50 anonymization - if ( swivel == 32 ) - return old_output ^ 1; - else - { - // Bits up to swivel are unchanged; bit swivel is flipped. - ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel); +ipaddr32_t AnonymizeIPAddr_A50::make_output(ipaddr32_t old_output, int swivel) const { + // -A50 anonymization + if ( swivel == 32 ) + return old_output ^ 1; + else { + // Bits up to swivel are unchanged; bit swivel is flipped. + ipaddr32_t known_part = ((old_output >> (32 - swivel)) ^ 1) << (32 - swivel); - // Remainder of bits are random. - return known_part | ((rand32() & 0x7FFFFFFF) >> swivel); - } - } + // Remainder of bits are random. + return known_part | ((rand32() & 0x7FFFFFFF) >> swivel); + } +} -AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) - { - if ( a == 0 || a == 0xFFFFFFFFU ) - reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree"); +AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::make_peer(ipaddr32_t a, Node* n) { + if ( a == 0 || a == 0xFFFFFFFFU ) + reporter->InternalError("0.0.0.0 and 255.255.255.255 should never get into the tree"); - // Become a peer. - // Algorithm: create two nodes, the two peers. Leave orig node as - // the parent of the two new ones. + // Become a peer. + // Algorithm: create two nodes, the two peers. Leave orig node as + // the parent of the two new ones. - Node* down[2]; - down[0] = new_node(); - if ( ! down[0] ) - return nullptr; + Node* down[2]; + down[0] = new_node(); + if ( ! down[0] ) + return nullptr; - down[1] = new_node(); - if ( ! down[1] ) - { - free_node(down[0]); - return nullptr; - } + down[1] = new_node(); + if ( ! down[1] ) { + free_node(down[0]); + return nullptr; + } - // swivel is first bit 'a' and 'old->input' differ. - int swivel = bi_ffs(a ^ n->input); + // swivel is first bit 'a' and 'old->input' differ. + int swivel = bi_ffs(a ^ n->input); - // bitvalue is the value of that bit of 'a'. - int bitvalue = (a >> (32 - swivel)) & 1; + // bitvalue is the value of that bit of 'a'. + int bitvalue = (a >> (32 - swivel)) & 1; - down[bitvalue]->input = a; - down[bitvalue]->output = make_output(n->output, swivel); - down[bitvalue]->child[0] = down[bitvalue]->child[1] = nullptr; + down[bitvalue]->input = a; + down[bitvalue]->output = make_output(n->output, swivel); + down[bitvalue]->child[0] = down[bitvalue]->child[1] = nullptr; - *down[1 - bitvalue] = *n; // copy orig node down one level + *down[1 - bitvalue] = *n; // copy orig node down one level - n->input = down[1]->input; // NB: 1s to the right (0s to the left) - n->output = down[1]->output; - n->child[0] = down[0]; // point to children - n->child[1] = down[1]; + n->input = down[1]->input; // NB: 1s to the right (0s to the left) + n->output = down[1]->output; + n->child[0] = down[0]; // point to children + n->child[1] = down[1]; - return down[bitvalue]; - } + return down[bitvalue]; +} -AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) - { - // Watch out for special IP addresses, which never make it - // into the tree. - if ( a == 0 || a == 0xFFFFFFFFU ) - return &special_nodes[a & 1]; +AnonymizeIPAddr_A50::Node* AnonymizeIPAddr_A50::find_node(ipaddr32_t a) { + // Watch out for special IP addresses, which never make it + // into the tree. + if ( a == 0 || a == 0xFFFFFFFFU ) + return &special_nodes[a & 1]; - if ( ! root ) - { - root = new_node(); - root->input = a; - root->output = rand32(); - root->child[0] = root->child[1] = nullptr; + if ( ! root ) { + root = new_node(); + root->input = a; + root->output = rand32(); + root->child[0] = root->child[1] = nullptr; - return root; - } + return root; + } - // Straight from tcpdpriv. - Node* n = root; - while ( n ) - { - if ( n->input == a ) - return n; + // Straight from tcpdpriv. + Node* n = root; + while ( n ) { + if ( n->input == a ) + return n; - if ( ! n->child[0] ) - n = make_peer(a, n); + if ( ! n->child[0] ) + n = make_peer(a, n); - else - { - // swivel is the first bit in which the two children - // differ. - int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input); + else { + // swivel is the first bit in which the two children + // differ. + int swivel = bi_ffs(n->child[0]->input ^ n->child[1]->input); - if ( bi_ffs(a ^ n->input) < swivel ) - // Input differs earlier. - n = make_peer(a, n); + if ( bi_ffs(a ^ n->input) < swivel ) + // Input differs earlier. + n = make_peer(a, n); - else if ( a & (1 << (32 - swivel)) ) - n = n->child[1]; + else if ( a & (1 << (32 - swivel)) ) + n = n->child[1]; - else - n = n->child[0]; - } - } + else + n = n->child[0]; + } + } - reporter->InternalError("out of memory!"); - return nullptr; - } + reporter->InternalError("out of memory!"); + return nullptr; +} static TableValPtr anon_preserve_orig_addr; static TableValPtr anon_preserve_resp_addr; static TableValPtr anon_preserve_other_addr; -void init_ip_addr_anonymizers() - { - ip_anonymizer[KEEP_ORIG_ADDR] = nullptr; - ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq(); - ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5(); - ip_anonymizer[PREFIX_PRESERVING_A50] = new AnonymizeIPAddr_A50(); - ip_anonymizer[PREFIX_PRESERVING_MD5] = new AnonymizeIPAddr_PrefixMD5(); +void init_ip_addr_anonymizers() { + ip_anonymizer[KEEP_ORIG_ADDR] = nullptr; + ip_anonymizer[SEQUENTIALLY_NUMBERED] = new AnonymizeIPAddr_Seq(); + ip_anonymizer[RANDOM_MD5] = new AnonymizeIPAddr_RandomMD5(); + ip_anonymizer[PREFIX_PRESERVING_A50] = new AnonymizeIPAddr_A50(); + ip_anonymizer[PREFIX_PRESERVING_MD5] = new AnonymizeIPAddr_PrefixMD5(); - auto id = global_scope()->Find("preserve_orig_addr"); + auto id = global_scope()->Find("preserve_orig_addr"); - if ( id ) - anon_preserve_orig_addr = cast_intrusive(id->GetVal()); + if ( id ) + anon_preserve_orig_addr = cast_intrusive(id->GetVal()); - id = global_scope()->Find("preserve_resp_addr"); + id = global_scope()->Find("preserve_resp_addr"); - if ( id ) - anon_preserve_resp_addr = cast_intrusive(id->GetVal()); + if ( id ) + anon_preserve_resp_addr = cast_intrusive(id->GetVal()); - id = global_scope()->Find("preserve_other_addr"); + id = global_scope()->Find("preserve_other_addr"); - if ( id ) - anon_preserve_other_addr = cast_intrusive(id->GetVal()); - } + if ( id ) + anon_preserve_other_addr = cast_intrusive(id->GetVal()); +} -ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) - { - TableVal* preserve_addr = nullptr; - auto addr = make_intrusive(ip); +ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl) { + TableVal* preserve_addr = nullptr; + auto addr = make_intrusive(ip); - int method = -1; + int method = -1; - switch ( cl ) - { - case ORIG_ADDR: // client address - preserve_addr = anon_preserve_orig_addr.get(); - method = orig_addr_anonymization; - break; + switch ( cl ) { + case ORIG_ADDR: // client address + preserve_addr = anon_preserve_orig_addr.get(); + method = orig_addr_anonymization; + break; - case RESP_ADDR: // server address - preserve_addr = anon_preserve_resp_addr.get(); - method = resp_addr_anonymization; - break; + case RESP_ADDR: // server address + preserve_addr = anon_preserve_resp_addr.get(); + method = resp_addr_anonymization; + break; - default: - preserve_addr = anon_preserve_other_addr.get(); - method = other_addr_anonymization; - break; - } + default: + preserve_addr = anon_preserve_other_addr.get(); + method = other_addr_anonymization; + break; + } - ipaddr32_t new_ip = 0; + ipaddr32_t new_ip = 0; - if ( preserve_addr && preserve_addr->FindOrDefault(addr) ) - new_ip = ip; + if ( preserve_addr && preserve_addr->FindOrDefault(addr) ) + new_ip = ip; - else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) - { - if ( method == KEEP_ORIG_ADDR ) - new_ip = ip; + else if ( method >= 0 && method < NUM_ADDR_ANONYMIZATION_METHODS ) { + if ( method == KEEP_ORIG_ADDR ) + new_ip = ip; - else if ( ! ip_anonymizer[method] ) - reporter->InternalError("IP anonymizer not initialized"); + else if ( ! ip_anonymizer[method] ) + reporter->InternalError("IP anonymizer not initialized"); - else - new_ip = ip_anonymizer[method]->Anonymize(ip); - } + else + new_ip = ip_anonymizer[method]->Anonymize(ip); + } - else - reporter->InternalError("invalid IP anonymization method"); + else + reporter->InternalError("invalid IP anonymization method"); #ifdef LOG_ANONYMIZATION_MAPPING - log_anonymization_mapping(ip, new_ip); + log_anonymization_mapping(ip, new_ip); #endif - return new_ip; - } + return new_ip; +} #ifdef LOG_ANONYMIZATION_MAPPING -void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) - { - if ( anonymization_mapping ) - event_mgr.Enqueue(anonymization_mapping, make_intrusive(input), - make_intrusive(output)); - } +void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output) { + if ( anonymization_mapping ) + event_mgr.Enqueue(anonymization_mapping, make_intrusive(input), make_intrusive(output)); +} #endif - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Anon.h b/src/Anon.h index f81b03bf94..454a47ca32 100644 --- a/src/Anon.h +++ b/src/Anon.h @@ -14,121 +14,111 @@ #include #include -namespace zeek::detail - { +namespace zeek::detail { // TODO: Anon.h may not be the right place to put these functions ... -enum ip_addr_anonymization_class_t - { - ORIG_ADDR, // client address - RESP_ADDR, // server address - OTHER_ADDR, - NUM_ADDR_ANONYMIZATION_CLASSES, - }; +enum ip_addr_anonymization_class_t { + ORIG_ADDR, // client address + RESP_ADDR, // server address + OTHER_ADDR, + NUM_ADDR_ANONYMIZATION_CLASSES, +}; -enum ip_addr_anonymization_method_t - { - KEEP_ORIG_ADDR, - SEQUENTIALLY_NUMBERED, - RANDOM_MD5, - PREFIX_PRESERVING_A50, - PREFIX_PRESERVING_MD5, - NUM_ADDR_ANONYMIZATION_METHODS, - }; +enum ip_addr_anonymization_method_t { + KEEP_ORIG_ADDR, + SEQUENTIALLY_NUMBERED, + RANDOM_MD5, + PREFIX_PRESERVING_A50, + PREFIX_PRESERVING_MD5, + NUM_ADDR_ANONYMIZATION_METHODS, +}; using ipaddr32_t = uint32_t; // NOTE: all addresses in parameters of *public* functions are in // network order. -class AnonymizeIPAddr - { +class AnonymizeIPAddr { public: - virtual ~AnonymizeIPAddr() = default; + virtual ~AnonymizeIPAddr() = default; - ipaddr32_t Anonymize(ipaddr32_t addr); + ipaddr32_t Anonymize(ipaddr32_t addr); - virtual bool PreservePrefix(ipaddr32_t input, int num_bits); + virtual bool PreservePrefix(ipaddr32_t input, int num_bits); - virtual ipaddr32_t anonymize(ipaddr32_t addr) = 0; + virtual ipaddr32_t anonymize(ipaddr32_t addr) = 0; - bool PreserveNet(ipaddr32_t input); + bool PreserveNet(ipaddr32_t input); protected: - std::map mapping; - }; + std::map mapping; +}; -class AnonymizeIPAddr_Seq : public AnonymizeIPAddr - { +class AnonymizeIPAddr_Seq : public AnonymizeIPAddr { public: - AnonymizeIPAddr_Seq() { seq = 1; } - ipaddr32_t anonymize(ipaddr32_t addr) override; + AnonymizeIPAddr_Seq() { seq = 1; } + ipaddr32_t anonymize(ipaddr32_t addr) override; protected: - ipaddr32_t seq; - }; + ipaddr32_t seq; +}; -class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr - { +class AnonymizeIPAddr_RandomMD5 : public AnonymizeIPAddr { public: - ipaddr32_t anonymize(ipaddr32_t addr) override; - }; + ipaddr32_t anonymize(ipaddr32_t addr) override; +}; -class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr - { +class AnonymizeIPAddr_PrefixMD5 : public AnonymizeIPAddr { public: - ipaddr32_t anonymize(ipaddr32_t addr) override; + ipaddr32_t anonymize(ipaddr32_t addr) override; protected: - struct anon_prefix - { - int len; - ipaddr32_t prefix; - } prefix; - }; + struct anon_prefix { + int len; + ipaddr32_t prefix; + } prefix; +}; -class AnonymizeIPAddr_A50 : public AnonymizeIPAddr - { +class AnonymizeIPAddr_A50 : public AnonymizeIPAddr { public: - AnonymizeIPAddr_A50() { init(); } - ~AnonymizeIPAddr_A50() override; + AnonymizeIPAddr_A50() { init(); } + ~AnonymizeIPAddr_A50() override; - ipaddr32_t anonymize(ipaddr32_t addr) override; - bool PreservePrefix(ipaddr32_t input, int num_bits) override; + ipaddr32_t anonymize(ipaddr32_t addr) override; + bool PreservePrefix(ipaddr32_t input, int num_bits) override; protected: - struct Node - { - ipaddr32_t input; - ipaddr32_t output; - Node* child[2]; - }; + struct Node { + ipaddr32_t input; + ipaddr32_t output; + Node* child[2]; + }; - int method; - int before_anonymization; - int new_mapping; + int method; + int before_anonymization; + int new_mapping; - // The root of prefix preserving mapping tree. - Node* root; + // The root of prefix preserving mapping tree. + Node* root; - // A node pool for new_node. - Node* next_free_node; - std::vector blocks; + // A node pool for new_node. + Node* next_free_node; + std::vector blocks; - // for 0.0.0.0 and 255.255.255.255. - Node special_nodes[2]; + // for 0.0.0.0 and 255.255.255.255. + Node special_nodes[2]; - void init(); + void init(); - Node* new_node(); - Node* new_node_block(); - void free_node(Node*); + Node* new_node(); + Node* new_node_block(); + void free_node(Node*); - ipaddr32_t make_output(ipaddr32_t, int) const; - Node* make_peer(ipaddr32_t, Node*); - Node* find_node(ipaddr32_t); - }; + ipaddr32_t make_output(ipaddr32_t, int) const; + Node* make_peer(ipaddr32_t, Node*); + Node* find_node(ipaddr32_t); +}; // The global IP anonymizers. extern AnonymizeIPAddr* ip_anonymizer[NUM_ADDR_ANONYMIZATION_METHODS]; @@ -139,4 +129,4 @@ ipaddr32_t anonymize_ip(ipaddr32_t ip, enum ip_addr_anonymization_class_t cl); #define LOG_ANONYMIZATION_MAPPING void log_anonymization_mapping(ipaddr32_t input, ipaddr32_t output); - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Attr.cc b/src/Attr.cc index 6ffcae2d7e..f3a5ab6b9e 100644 --- a/src/Attr.cc +++ b/src/Attr.cc @@ -11,13 +11,11 @@ #include "zeek/input/Manager.h" #include "zeek/threading/SerialTypes.h" -namespace zeek::detail - { +namespace zeek::detail { -const char* attr_name(AttrTag t) - { - // Do not collapse the list. - // clang-format off +const char* attr_name(AttrTag t) { + // Do not collapse the list. + // clang-format off static const char* attr_names[int(NUM_ATTRS)] = { "&optional", "&default", @@ -45,395 +43,335 @@ const char* attr_name(AttrTag t) "&is_used", "&ordered", }; - // clang-format on + // clang-format on - return attr_names[int(t)]; - } + return attr_names[int(t)]; +} -Attr::Attr(AttrTag t, ExprPtr e) : expr(std::move(e)) - { - tag = t; - SetLocationInfo(&start_location, &end_location); - } +Attr::Attr(AttrTag t, ExprPtr e) : expr(std::move(e)) { + tag = t; + SetLocationInfo(&start_location, &end_location); +} -Attr::Attr(AttrTag t) : Attr(t, nullptr) { } +Attr::Attr(AttrTag t) : Attr(t, nullptr) {} -void Attr::SetAttrExpr(ExprPtr e) - { - expr = std::move(e); - } +void Attr::SetAttrExpr(ExprPtr e) { expr = std::move(e); } -std::string Attr::DeprecationMessage() const - { - if ( tag != ATTR_DEPRECATED ) - return ""; +std::string Attr::DeprecationMessage() const { + if ( tag != ATTR_DEPRECATED ) + return ""; - if ( ! expr ) - return ""; + if ( ! expr ) + return ""; - auto ce = static_cast(expr.get()); - return ce->Value()->AsStringVal()->CheckString(); - } + auto ce = static_cast(expr.get()); + return ce->Value()->AsStringVal()->CheckString(); +} -void Attr::Describe(ODesc* d) const - { - AddTag(d); +void Attr::Describe(ODesc* d) const { + AddTag(d); - if ( expr ) - { - if ( ! d->IsBinary() ) - d->Add("="); + if ( expr ) { + if ( ! d->IsBinary() ) + d->Add("="); - expr->Describe(d); - } - } + expr->Describe(d); + } +} -void Attr::DescribeReST(ODesc* d, bool shorten) const - { - auto add_long_expr_string = [](ODesc* d, const std::string& s, bool shorten) - { - constexpr auto max_expr_chars = 32; - constexpr auto shortened_expr = "*...*"; +void Attr::DescribeReST(ODesc* d, bool shorten) const { + auto add_long_expr_string = [](ODesc* d, const std::string& s, bool shorten) { + constexpr auto max_expr_chars = 32; + constexpr auto shortened_expr = "*...*"; - if ( s.size() > max_expr_chars ) - { - if ( shorten ) - d->Add(shortened_expr); - else - { - // Long inline-literals likely won't wrap well in HTML render - d->Add("*"); - d->Add(s); - d->Add("*"); - } - } - else - { - d->Add("``"); - d->Add(s); - d->Add("``"); - } - }; + if ( s.size() > max_expr_chars ) { + if ( shorten ) + d->Add(shortened_expr); + else { + // Long inline-literals likely won't wrap well in HTML render + d->Add("*"); + d->Add(s); + d->Add("*"); + } + } + else { + d->Add("``"); + d->Add(s); + d->Add("``"); + } + }; - d->Add(":zeek:attr:`"); - AddTag(d); - d->Add("`"); + d->Add(":zeek:attr:`"); + AddTag(d); + d->Add("`"); - if ( expr ) - { - d->SP(); - d->Add("="); - d->SP(); + if ( expr ) { + d->SP(); + d->Add("="); + d->SP(); - if ( expr->Tag() == EXPR_NAME ) - { - d->Add(":zeek:see:`"); - expr->Describe(d); - d->Add("`"); - } + if ( expr->Tag() == EXPR_NAME ) { + d->Add(":zeek:see:`"); + expr->Describe(d); + d->Add("`"); + } - else if ( expr->GetType()->Tag() == TYPE_FUNC ) - { - d->Add(":zeek:type:`"); - d->Add(expr->GetType()->AsFuncType()->FlavorString()); - d->Add("`"); - } + else if ( expr->GetType()->Tag() == TYPE_FUNC ) { + d->Add(":zeek:type:`"); + d->Add(expr->GetType()->AsFuncType()->FlavorString()); + d->Add("`"); + } - else if ( expr->Tag() == EXPR_CONST ) - { - ODesc dd; - dd.SetQuotes(true); - expr->Describe(&dd); - std::string s = dd.Description(); - add_long_expr_string(d, s, shorten); - } + else if ( expr->Tag() == EXPR_CONST ) { + ODesc dd; + dd.SetQuotes(true); + expr->Describe(&dd); + std::string s = dd.Description(); + add_long_expr_string(d, s, shorten); + } - else - { - ODesc dd; - expr->Eval(nullptr)->Describe(&dd); - std::string s = dd.Description(); + else { + ODesc dd; + expr->Eval(nullptr)->Describe(&dd); + std::string s = dd.Description(); - for ( size_t i = 0; i < s.size(); ++i ) - if ( s[i] == '\n' ) - s[i] = ' '; + for ( size_t i = 0; i < s.size(); ++i ) + if ( s[i] == '\n' ) + s[i] = ' '; - add_long_expr_string(d, s, shorten); - } - } - } + add_long_expr_string(d, s, shorten); + } + } +} -void Attr::AddTag(ODesc* d) const - { - if ( d->IsBinary() ) - d->Add(static_cast(Tag())); - else - d->Add(attr_name(Tag())); - } +void Attr::AddTag(ODesc* d) const { + if ( d->IsBinary() ) + d->Add(static_cast(Tag())); + else + d->Add(attr_name(Tag())); +} -detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const - { - auto tc = cb->PreAttr(this); - HANDLE_TC_ATTR_PRE(tc); +detail::TraversalCode Attr::Traverse(detail::TraversalCallback* cb) const { + auto tc = cb->PreAttr(this); + HANDLE_TC_ATTR_PRE(tc); - if ( expr ) - { - auto tc = expr->Traverse(cb); - HANDLE_TC_ATTR_PRE(tc); - } + if ( expr ) { + auto tc = expr->Traverse(cb); + HANDLE_TC_ATTR_PRE(tc); + } - tc = cb->PostAttr(this); - HANDLE_TC_ATTR_POST(tc); - } + tc = cb->PostAttr(this); + HANDLE_TC_ATTR_POST(tc); +} Attributes::Attributes(TypePtr t, bool arg_in_record, bool is_global) - : Attributes(std::vector{}, std::move(t), arg_in_record, is_global) - { - } + : Attributes(std::vector{}, std::move(t), arg_in_record, is_global) {} -Attributes::Attributes(std::vector a, TypePtr t, bool arg_in_record, bool is_global) - : type(std::move(t)) - { - attrs.reserve(a.size()); - in_record = arg_in_record; - global_var = is_global; +Attributes::Attributes(std::vector a, TypePtr t, bool arg_in_record, bool is_global) : type(std::move(t)) { + attrs.reserve(a.size()); + in_record = arg_in_record; + global_var = is_global; - SetLocationInfo(&start_location, &end_location); + SetLocationInfo(&start_location, &end_location); - // We loop through 'a' and add each attribute individually, - // rather than just taking over 'a' for ourselves, so that - // the necessary checking gets done. + // We loop through 'a' and add each attribute individually, + // rather than just taking over 'a' for ourselves, so that + // the necessary checking gets done. - for ( auto& attr : a ) - AddAttr(std::move(attr)); - } + for ( auto& attr : a ) + AddAttr(std::move(attr)); +} -void Attributes::AddAttr(AttrPtr attr, bool is_redef) - { - auto acceptable_duplicate_attr = [](const AttrPtr& attr, const AttrPtr& existing) -> bool - { - if ( attr == existing ) - return true; +void Attributes::AddAttr(AttrPtr attr, bool is_redef) { + auto acceptable_duplicate_attr = [](const AttrPtr& attr, const AttrPtr& existing) -> bool { + if ( attr == existing ) + return true; - AttrTag new_tag = attr->Tag(); + AttrTag new_tag = attr->Tag(); - if ( new_tag == ATTR_DEPRECATED ) - { - if ( ! attr->DeprecationMessage().empty() || - (existing && ! existing->DeprecationMessage().empty()) ) - return false; + if ( new_tag == ATTR_DEPRECATED ) { + if ( ! attr->DeprecationMessage().empty() || (existing && ! existing->DeprecationMessage().empty()) ) + return false; - return true; - } + return true; + } - return new_tag == ATTR_LOG || new_tag == ATTR_OPTIONAL || new_tag == ATTR_REDEF || - new_tag == ATTR_BROKER_STORE_ALLOW_COMPLEX || new_tag == ATTR_RAW_OUTPUT || - new_tag == ATTR_ERROR_HANDLER || new_tag == ATTR_IS_USED; - }; + return new_tag == ATTR_LOG || new_tag == ATTR_OPTIONAL || new_tag == ATTR_REDEF || + new_tag == ATTR_BROKER_STORE_ALLOW_COMPLEX || new_tag == ATTR_RAW_OUTPUT || + new_tag == ATTR_ERROR_HANDLER || new_tag == ATTR_IS_USED; + }; - // A `redef` is allowed to overwrite an existing attribute instead of - // flagging it as ambiguous. - if ( ! is_redef ) - { - auto existing = Find(attr->Tag()); - if ( existing && ! acceptable_duplicate_attr(attr, existing) ) - reporter->Error("Duplicate %s attribute is ambiguous", attr_name(attr->Tag())); - } + // A `redef` is allowed to overwrite an existing attribute instead of + // flagging it as ambiguous. + if ( ! is_redef ) { + auto existing = Find(attr->Tag()); + if ( existing && ! acceptable_duplicate_attr(attr, existing) ) + reporter->Error("Duplicate %s attribute is ambiguous", attr_name(attr->Tag())); + } - // We overwrite old attributes by deleting them first. - RemoveAttr(attr->Tag()); - attrs.emplace_back(attr); + // We overwrite old attributes by deleting them first. + RemoveAttr(attr->Tag()); + attrs.emplace_back(attr); - // We only check the attribute after we've added it, to facilitate - // generating error messages via Attributes::Describe. If the - // instantiator of the object specified a null type, however, then - // that's a signal to skip the checking. If the type is error, - // there's no point checking attributes either. - if ( type && ! IsErrorType(type->Tag()) ) - CheckAttr(attr.get()); + // We only check the attribute after we've added it, to facilitate + // generating error messages via Attributes::Describe. If the + // instantiator of the object specified a null type, however, then + // that's a signal to skip the checking. If the type is error, + // there's no point checking attributes either. + if ( type && ! IsErrorType(type->Tag()) ) + CheckAttr(attr.get()); - // For ADD_FUNC or DEL_FUNC, add in an implicit REDEF, since - // those attributes only have meaning for a redefinable value. - if ( (attr->Tag() == ATTR_ADD_FUNC || attr->Tag() == ATTR_DEL_FUNC) && ! Find(ATTR_REDEF) ) - { - auto a = make_intrusive(ATTR_REDEF); - attrs.emplace_back(std::move(a)); - } + // For ADD_FUNC or DEL_FUNC, add in an implicit REDEF, since + // those attributes only have meaning for a redefinable value. + if ( (attr->Tag() == ATTR_ADD_FUNC || attr->Tag() == ATTR_DEL_FUNC) && ! Find(ATTR_REDEF) ) { + auto a = make_intrusive(ATTR_REDEF); + attrs.emplace_back(std::move(a)); + } - // For DEFAULT, add an implicit OPTIONAL if it's not a global. - if ( ! global_var && attr->Tag() == ATTR_DEFAULT && ! Find(ATTR_OPTIONAL) ) - { - auto a = make_intrusive(ATTR_OPTIONAL); - attrs.emplace_back(std::move(a)); - } - } + // For DEFAULT, add an implicit OPTIONAL if it's not a global. + if ( ! global_var && attr->Tag() == ATTR_DEFAULT && ! Find(ATTR_OPTIONAL) ) { + auto a = make_intrusive(ATTR_OPTIONAL); + attrs.emplace_back(std::move(a)); + } +} -void Attributes::AddAttrs(const AttributesPtr& a, bool is_redef) - { - for ( const auto& attr : a->GetAttrs() ) - AddAttr(attr, is_redef); - } +void Attributes::AddAttrs(const AttributesPtr& a, bool is_redef) { + for ( const auto& attr : a->GetAttrs() ) + AddAttr(attr, is_redef); +} -const AttrPtr& Attributes::Find(AttrTag t) const - { - for ( const auto& a : attrs ) - if ( a->Tag() == t ) - return a; +const AttrPtr& Attributes::Find(AttrTag t) const { + for ( const auto& a : attrs ) + if ( a->Tag() == t ) + return a; - return Attr::nil; - } + return Attr::nil; +} -void Attributes::RemoveAttr(AttrTag t) - { - for ( auto it = attrs.begin(); it != attrs.end(); ) - { - if ( (*it)->Tag() == t ) - it = attrs.erase(it); - else - ++it; - } - } +void Attributes::RemoveAttr(AttrTag t) { + for ( auto it = attrs.begin(); it != attrs.end(); ) { + if ( (*it)->Tag() == t ) + it = attrs.erase(it); + else + ++it; + } +} -void Attributes::Describe(ODesc* d) const - { - if ( attrs.empty() ) - { - d->AddCount(0); - return; - } +void Attributes::Describe(ODesc* d) const { + if ( attrs.empty() ) { + d->AddCount(0); + return; + } - d->AddCount(static_cast(attrs.size())); + d->AddCount(static_cast(attrs.size())); - for ( size_t i = 0; i < attrs.size(); ++i ) - { - if ( d->IsReadable() && i > 0 ) - d->Add(", "); + for ( size_t i = 0; i < attrs.size(); ++i ) { + if ( d->IsReadable() && i > 0 ) + d->Add(", "); - attrs[i]->Describe(d); - } - } + attrs[i]->Describe(d); + } +} -void Attributes::DescribeReST(ODesc* d, bool shorten) const - { - for ( size_t i = 0; i < attrs.size(); ++i ) - { - if ( i > 0 ) - d->Add(" "); +void Attributes::DescribeReST(ODesc* d, bool shorten) const { + for ( size_t i = 0; i < attrs.size(); ++i ) { + if ( i > 0 ) + d->Add(" "); - attrs[i]->DescribeReST(d, shorten); - } - } + attrs[i]->DescribeReST(d, shorten); + } +} -void Attributes::CheckAttr(Attr* a) - { - switch ( a->Tag() ) - { - case ATTR_DEPRECATED: - case ATTR_REDEF: - case ATTR_IS_ASSIGNED: - case ATTR_IS_USED: - break; +void Attributes::CheckAttr(Attr* a) { + switch ( a->Tag() ) { + case ATTR_DEPRECATED: + case ATTR_REDEF: + case ATTR_IS_ASSIGNED: + case ATTR_IS_USED: break; - case ATTR_OPTIONAL: - if ( global_var ) - Error("&optional is not valid for global variables"); - break; + case ATTR_OPTIONAL: + if ( global_var ) + Error("&optional is not valid for global variables"); + break; - case ATTR_ADD_FUNC: - case ATTR_DEL_FUNC: - { - bool is_add = a->Tag() == ATTR_ADD_FUNC; + case ATTR_ADD_FUNC: + case ATTR_DEL_FUNC: { + bool is_add = a->Tag() == ATTR_ADD_FUNC; - const auto& at = a->GetExpr()->GetType(); - if ( at->Tag() != TYPE_FUNC ) - { - a->GetExpr()->Error(is_add ? "&add_func must be a function" - : "&delete_func must be a function"); - break; - } + const auto& at = a->GetExpr()->GetType(); + if ( at->Tag() != TYPE_FUNC ) { + a->GetExpr()->Error(is_add ? "&add_func must be a function" : "&delete_func must be a function"); + break; + } - FuncType* aft = at->AsFuncType(); - if ( ! same_type(aft->Yield(), type) ) - { - a->GetExpr()->Error(is_add - ? "&add_func function must yield same type as variable" - : "&delete_func function must yield same type as variable"); - break; - } - } - break; + FuncType* aft = at->AsFuncType(); + if ( ! same_type(aft->Yield(), type) ) { + a->GetExpr()->Error(is_add ? "&add_func function must yield same type as variable" : + "&delete_func function must yield same type as variable"); + break; + } + } break; - case ATTR_DEFAULT_INSERT: - { - if ( ! type->IsTable() ) - { - Error("&default_insert only applicable to tables"); - break; - } + case ATTR_DEFAULT_INSERT: { + if ( ! type->IsTable() ) { + Error("&default_insert only applicable to tables"); + break; + } - if ( Find(ATTR_DEFAULT) ) - { - Error("&default and &default_insert cannot be used together"); - break; - } + if ( Find(ATTR_DEFAULT) ) { + Error("&default and &default_insert cannot be used together"); + break; + } - std::string err_msg; - if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && - ! err_msg.empty() ) - Error(err_msg.c_str()); - break; - } + std::string err_msg; + if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && ! err_msg.empty() ) + Error(err_msg.c_str()); + break; + } - case ATTR_DEFAULT: - { - if ( Find(ATTR_DEFAULT_INSERT) ) - { - Error("&default and &default_insert cannot be used together"); - break; - } + case ATTR_DEFAULT: { + if ( Find(ATTR_DEFAULT_INSERT) ) { + Error("&default and &default_insert cannot be used together"); + break; + } - std::string err_msg; - if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && - ! err_msg.empty() ) - Error(err_msg.c_str()); - break; - } + std::string err_msg; + if ( ! check_default_attr(a, type, global_var, in_record, err_msg) && ! err_msg.empty() ) + Error(err_msg.c_str()); + break; + } - case ATTR_EXPIRE_READ: - { - if ( Find(ATTR_BROKER_STORE) ) - Error("&broker_store and &read_expire cannot be used simultaneously"); + case ATTR_EXPIRE_READ: { + if ( Find(ATTR_BROKER_STORE) ) + Error("&broker_store and &read_expire cannot be used simultaneously"); - if ( Find(ATTR_BACKEND) ) - Error("&backend and &read_expire cannot be used simultaneously"); - } - // fallthrough + if ( Find(ATTR_BACKEND) ) + Error("&backend and &read_expire cannot be used simultaneously"); + } + // fallthrough - case ATTR_EXPIRE_WRITE: - case ATTR_EXPIRE_CREATE: - { - if ( type->Tag() != TYPE_TABLE ) - { - Error("expiration only applicable to sets/tables"); - break; - } + case ATTR_EXPIRE_WRITE: + case ATTR_EXPIRE_CREATE: { + if ( type->Tag() != TYPE_TABLE ) { + Error("expiration only applicable to sets/tables"); + break; + } - int num_expires = 0; + int num_expires = 0; - for ( const auto& at : attrs ) - { - if ( at->Tag() == ATTR_EXPIRE_READ || at->Tag() == ATTR_EXPIRE_WRITE || - at->Tag() == ATTR_EXPIRE_CREATE ) - num_expires++; - } + for ( const auto& at : attrs ) { + if ( at->Tag() == ATTR_EXPIRE_READ || at->Tag() == ATTR_EXPIRE_WRITE || + at->Tag() == ATTR_EXPIRE_CREATE ) + num_expires++; + } - if ( num_expires > 1 ) - { - Error("set/table can only have one of &read_expire, &write_expire, " - "&create_expire"); - break; - } - } + if ( num_expires > 1 ) { + Error( + "set/table can only have one of &read_expire, &write_expire, " + "&create_expire"); + break; + } + } #if 0 //### not easy to test this w/o knowing the ID. @@ -441,389 +379,345 @@ void Attributes::CheckAttr(Attr* a) Error("expiration not supported for local variables"); #endif - break; - - case ATTR_EXPIRE_FUNC: - { - if ( type->Tag() != TYPE_TABLE ) - { - Error("expiration only applicable to tables"); - break; - } - - type->AsTableType()->CheckExpireFuncCompatibility({NewRef{}, a}); - - if ( Find(ATTR_BROKER_STORE) ) - Error("&broker_store and &expire_func cannot be used simultaneously"); - - if ( Find(ATTR_BACKEND) ) - Error("&backend and &expire_func cannot be used simultaneously"); - - break; - } - - case ATTR_ON_CHANGE: - { - if ( type->Tag() != TYPE_TABLE ) - { - Error("&on_change only applicable to sets/tables"); - break; - } - - const auto& change_func = a->GetExpr(); - - if ( change_func->GetType()->Tag() != TYPE_FUNC || - change_func->GetType()->AsFuncType()->Flavor() != FUNC_FLAVOR_FUNCTION ) - Error("&on_change attribute is not a function"); - - const FuncType* c_ft = change_func->GetType()->AsFuncType(); - - if ( c_ft->Yield()->Tag() != TYPE_VOID ) - { - Error("&on_change must not return a value"); - break; - } - - const TableType* the_table = type->AsTableType(); - - if ( the_table->IsUnspecifiedTable() ) - break; - - const auto& args = c_ft->ParamList()->GetTypes(); - const auto& t_indexes = the_table->GetIndexTypes(); - if ( args.size() != (type->IsSet() ? 2 : 3) + t_indexes.size() ) - { - Error("&on_change function has incorrect number of arguments"); - break; - } - - if ( ! same_type(args[0], the_table->AsTableType()) ) - { - Error("&on_change: first argument must be of same type as table"); - break; - } - - // can't check exact type here yet - the data structures don't exist yet. - if ( args[1]->Tag() != TYPE_ENUM ) - { - Error("&on_change: second argument must be a TableChange enum"); - break; - } - - for ( size_t i = 0; i < t_indexes.size(); i++ ) - { - if ( ! same_type(args[2 + i], t_indexes[i]) ) - { - Error("&on_change: index types do not match table"); - break; - } - } - - if ( ! type->IsSet() ) - if ( ! same_type(args[2 + t_indexes.size()], the_table->Yield()) ) - { - Error("&on_change: value type does not match table"); - break; - } - } - break; - - case ATTR_BACKEND: - { - if ( ! global_var || type->Tag() != TYPE_TABLE ) - { - Error("&backend only applicable to global sets/tables"); - break; - } - - // cannot do better equality check - the Broker types are not - // actually existing yet when we are here. We will do that - // later - before actually attaching to a broker store - if ( a->GetExpr()->GetType()->Tag() != TYPE_ENUM ) - { - Error("&backend must take an enum argument"); - break; - } - - // Only support atomic types for the moment, unless - // explicitly overridden - if ( ! type->AsTableType()->IsSet() && - ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && - ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) - { - Error("&backend only supports atomic types as table value"); - } - - if ( Find(ATTR_EXPIRE_FUNC) ) - Error("&backend and &expire_func cannot be used simultaneously"); - - if ( Find(ATTR_EXPIRE_READ) ) - Error("&backend and &read_expire cannot be used simultaneously"); - - if ( Find(ATTR_BROKER_STORE) ) - Error("&backend and &broker_store cannot be used simultaneously"); - - break; - } - - case ATTR_BROKER_STORE: - { - if ( type->Tag() != TYPE_TABLE ) - { - Error("&broker_store only applicable to sets/tables"); - break; - } - - if ( a->GetExpr()->GetType()->Tag() != TYPE_STRING ) - { - Error("&broker_store must take a string argument"); - break; - } - - // Only support atomic types for the moment, unless - // explicitly overridden - if ( ! type->AsTableType()->IsSet() && - ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && - ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) - { - Error("&broker_store only supports atomic types as table value"); - } - - if ( Find(ATTR_EXPIRE_FUNC) ) - Error("&broker_store and &expire_func cannot be used simultaneously"); - - if ( Find(ATTR_EXPIRE_READ) ) - Error("&broker_store and &read_expire cannot be used simultaneously"); - - if ( Find(ATTR_BACKEND) ) - Error("&backend and &broker_store cannot be used simultaneously"); - - break; - } - - case ATTR_BROKER_STORE_ALLOW_COMPLEX: - { - if ( type->Tag() != TYPE_TABLE ) - { - Error("&broker_allow_complex_type only applicable to sets/tables"); - break; - } - } - - case ATTR_TRACKED: - // FIXME: Check here for global ID? - break; - - case ATTR_RAW_OUTPUT: - if ( type->Tag() != TYPE_FILE ) - Error("&raw_output only applicable to files"); - break; - - case ATTR_PRIORITY: - Error("&priority only applicable to event bodies"); - break; - - case ATTR_GROUP: - if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT ) - Error("&group only applicable to events"); - break; - - case ATTR_ERROR_HANDLER: - if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT ) - Error("&error_handler only applicable to events"); - break; - - case ATTR_LOG: - if ( ! threading::Value::IsCompatibleType(type.get()) ) - Error("&log applied to a type that cannot be logged"); - break; - - case ATTR_TYPE_COLUMN: - { - if ( type->Tag() != TYPE_PORT ) - { - Error("type_column tag only applicable to ports"); - break; - } - - const auto& atype = a->GetExpr()->GetType(); - - if ( atype->Tag() != TYPE_STRING ) - { - Error("type column needs to have a string argument"); - break; - } - - break; - } - - case ATTR_ORDERED: - if ( type->Tag() != TYPE_TABLE ) - Error("&ordered only applicable to tables"); - break; - - default: - BadTag("Attributes::CheckAttr", attr_name(a->Tag())); - } - } - -bool Attributes::operator==(const Attributes& other) const - { - if ( attrs.empty() ) - return other.attrs.empty(); - - if ( other.attrs.empty() ) - return false; - - for ( const auto& a : attrs ) - { - const auto& o = other.Find(a->Tag()); - - if ( ! o ) - return false; - - if ( ! (*a == *o) ) - return false; - } - - for ( const auto& o : other.attrs ) - { - const auto& a = Find(o->Tag()); - - if ( ! a ) - return false; - - if ( ! (*a == *o) ) - return false; - } - - return true; - } - -bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, - std::string& err_msg) - { - ASSERT(a->Tag() == ATTR_DEFAULT || a->Tag() == ATTR_DEFAULT_INSERT); - std::string aname = attr_name(a->Tag()); - // &default is allowed for global tables, since it's used in - // initialization of table fields. It's not allowed otherwise. - if ( global_var && ! type->IsTable() ) - { - err_msg = aname + " is not valid for global variables except for tables"; - return false; - } - - const auto& atype = a->GetExpr()->GetType(); - - if ( type->Tag() != TYPE_TABLE || (type->IsSet() && ! in_record) ) - { - if ( same_type(atype, type) ) - // Ok. - return true; - - // Record defaults may be promotable. - if ( (type->Tag() == TYPE_RECORD && atype->Tag() == TYPE_RECORD && - record_promotion_compatible(atype->AsRecordType(), type->AsRecordType())) ) - // Ok. - return true; - - if ( type->Tag() == TYPE_TABLE && type->AsTableType()->IsUnspecifiedTable() ) - // Ok. - return true; - - auto e = check_and_promote_expr(a->GetExpr(), type); - - if ( e ) - { - a->SetAttrExpr(std::move(e)); - // Ok. - return true; - } - - a->GetExpr()->Error(util::fmt("%s value has inconsistent type", aname.c_str()), type.get()); - return false; - } - - TableType* tt = type->AsTableType(); - const auto& ytype = tt->Yield(); - - if ( ! in_record ) - { // &default applies to the type itself. - if ( same_type(atype, ytype) ) - return true; - - // It can still be a default function. - if ( atype->Tag() == TYPE_FUNC ) - { - FuncType* f = atype->AsFuncType(); - if ( ! f->CheckArgs(tt->GetIndexTypes()) || ! same_type(f->Yield(), ytype) ) - { - err_msg = aname + " function type clash"; - return false; - } - - // Ok. - return true; - } - - // Table defaults may be promotable. - if ( (ytype->Tag() == TYPE_RECORD && atype->Tag() == TYPE_RECORD && - record_promotion_compatible(atype->AsRecordType(), ytype->AsRecordType())) ) - // Ok. - return true; - - auto e = check_and_promote_expr(a->GetExpr(), ytype); - - if ( e ) - { - a->SetAttrExpr(std::move(e)); - // Ok. - return true; - } - - err_msg = aname + " value has inconsistent type"; - return false; - } - - // &default applies to record field. - - if ( same_type(atype, type) ) - return true; - - if ( (atype->Tag() == TYPE_TABLE && atype->AsTableType()->IsUnspecifiedTable()) ) - { - auto e = check_and_promote_expr(a->GetExpr(), type); - - if ( e ) - { - a->SetAttrExpr(std::move(e)); - return true; - } - } - - // Table defaults may be promotable. - if ( ytype && ytype->Tag() == TYPE_RECORD && atype->Tag() == TYPE_RECORD && - record_promotion_compatible(atype->AsRecordType(), ytype->AsRecordType()) ) - // Ok. - return true; - - err_msg = "&default value has inconsistent type"; - return false; - } - -detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const - { - auto tc = cb->PreAttrs(this); - HANDLE_TC_ATTRS_PRE(tc); - - for ( const auto& a : attrs ) - { - tc = a->Traverse(cb); - HANDLE_TC_ATTRS_PRE(tc); - } - - tc = cb->PostAttrs(this); - HANDLE_TC_ATTRS_POST(tc); - } - - } + break; + + case ATTR_EXPIRE_FUNC: { + if ( type->Tag() != TYPE_TABLE ) { + Error("expiration only applicable to tables"); + break; + } + + type->AsTableType()->CheckExpireFuncCompatibility({NewRef{}, a}); + + if ( Find(ATTR_BROKER_STORE) ) + Error("&broker_store and &expire_func cannot be used simultaneously"); + + if ( Find(ATTR_BACKEND) ) + Error("&backend and &expire_func cannot be used simultaneously"); + + break; + } + + case ATTR_ON_CHANGE: { + if ( type->Tag() != TYPE_TABLE ) { + Error("&on_change only applicable to sets/tables"); + break; + } + + const auto& change_func = a->GetExpr(); + + if ( change_func->GetType()->Tag() != TYPE_FUNC || + change_func->GetType()->AsFuncType()->Flavor() != FUNC_FLAVOR_FUNCTION ) + Error("&on_change attribute is not a function"); + + const FuncType* c_ft = change_func->GetType()->AsFuncType(); + + if ( c_ft->Yield()->Tag() != TYPE_VOID ) { + Error("&on_change must not return a value"); + break; + } + + const TableType* the_table = type->AsTableType(); + + if ( the_table->IsUnspecifiedTable() ) + break; + + const auto& args = c_ft->ParamList()->GetTypes(); + const auto& t_indexes = the_table->GetIndexTypes(); + if ( args.size() != (type->IsSet() ? 2 : 3) + t_indexes.size() ) { + Error("&on_change function has incorrect number of arguments"); + break; + } + + if ( ! same_type(args[0], the_table->AsTableType()) ) { + Error("&on_change: first argument must be of same type as table"); + break; + } + + // can't check exact type here yet - the data structures don't exist yet. + if ( args[1]->Tag() != TYPE_ENUM ) { + Error("&on_change: second argument must be a TableChange enum"); + break; + } + + for ( size_t i = 0; i < t_indexes.size(); i++ ) { + if ( ! same_type(args[2 + i], t_indexes[i]) ) { + Error("&on_change: index types do not match table"); + break; + } + } + + if ( ! type->IsSet() ) + if ( ! same_type(args[2 + t_indexes.size()], the_table->Yield()) ) { + Error("&on_change: value type does not match table"); + break; + } + } break; + + case ATTR_BACKEND: { + if ( ! global_var || type->Tag() != TYPE_TABLE ) { + Error("&backend only applicable to global sets/tables"); + break; + } + + // cannot do better equality check - the Broker types are not + // actually existing yet when we are here. We will do that + // later - before actually attaching to a broker store + if ( a->GetExpr()->GetType()->Tag() != TYPE_ENUM ) { + Error("&backend must take an enum argument"); + break; + } + + // Only support atomic types for the moment, unless + // explicitly overridden + if ( ! type->AsTableType()->IsSet() && + ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && + ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) { + Error("&backend only supports atomic types as table value"); + } + + if ( Find(ATTR_EXPIRE_FUNC) ) + Error("&backend and &expire_func cannot be used simultaneously"); + + if ( Find(ATTR_EXPIRE_READ) ) + Error("&backend and &read_expire cannot be used simultaneously"); + + if ( Find(ATTR_BROKER_STORE) ) + Error("&backend and &broker_store cannot be used simultaneously"); + + break; + } + + case ATTR_BROKER_STORE: { + if ( type->Tag() != TYPE_TABLE ) { + Error("&broker_store only applicable to sets/tables"); + break; + } + + if ( a->GetExpr()->GetType()->Tag() != TYPE_STRING ) { + Error("&broker_store must take a string argument"); + break; + } + + // Only support atomic types for the moment, unless + // explicitly overridden + if ( ! type->AsTableType()->IsSet() && + ! input::Manager::IsCompatibleType(type->AsTableType()->Yield().get(), true) && + ! Find(ATTR_BROKER_STORE_ALLOW_COMPLEX) ) { + Error("&broker_store only supports atomic types as table value"); + } + + if ( Find(ATTR_EXPIRE_FUNC) ) + Error("&broker_store and &expire_func cannot be used simultaneously"); + + if ( Find(ATTR_EXPIRE_READ) ) + Error("&broker_store and &read_expire cannot be used simultaneously"); + + if ( Find(ATTR_BACKEND) ) + Error("&backend and &broker_store cannot be used simultaneously"); + + break; + } + + case ATTR_BROKER_STORE_ALLOW_COMPLEX: { + if ( type->Tag() != TYPE_TABLE ) { + Error("&broker_allow_complex_type only applicable to sets/tables"); + break; + } + } + + case ATTR_TRACKED: + // FIXME: Check here for global ID? + break; + + case ATTR_RAW_OUTPUT: + if ( type->Tag() != TYPE_FILE ) + Error("&raw_output only applicable to files"); + break; + + case ATTR_PRIORITY: Error("&priority only applicable to event bodies"); break; + + case ATTR_GROUP: + if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT ) + Error("&group only applicable to events"); + break; + + case ATTR_ERROR_HANDLER: + if ( type->Tag() != TYPE_FUNC || type->AsFuncType()->Flavor() != FUNC_FLAVOR_EVENT ) + Error("&error_handler only applicable to events"); + break; + + case ATTR_LOG: + if ( ! threading::Value::IsCompatibleType(type.get()) ) + Error("&log applied to a type that cannot be logged"); + break; + + case ATTR_TYPE_COLUMN: { + if ( type->Tag() != TYPE_PORT ) { + Error("type_column tag only applicable to ports"); + break; + } + + const auto& atype = a->GetExpr()->GetType(); + + if ( atype->Tag() != TYPE_STRING ) { + Error("type column needs to have a string argument"); + break; + } + + break; + } + + case ATTR_ORDERED: + if ( type->Tag() != TYPE_TABLE ) + Error("&ordered only applicable to tables"); + break; + + default: BadTag("Attributes::CheckAttr", attr_name(a->Tag())); + } +} + +bool Attributes::operator==(const Attributes& other) const { + if ( attrs.empty() ) + return other.attrs.empty(); + + if ( other.attrs.empty() ) + return false; + + for ( const auto& a : attrs ) { + const auto& o = other.Find(a->Tag()); + + if ( ! o ) + return false; + + if ( ! (*a == *o) ) + return false; + } + + for ( const auto& o : other.attrs ) { + const auto& a = Find(o->Tag()); + + if ( ! a ) + return false; + + if ( ! (*a == *o) ) + return false; + } + + return true; +} + +bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, std::string& err_msg) { + ASSERT(a->Tag() == ATTR_DEFAULT || a->Tag() == ATTR_DEFAULT_INSERT); + std::string aname = attr_name(a->Tag()); + // &default is allowed for global tables, since it's used in + // initialization of table fields. It's not allowed otherwise. + if ( global_var && ! type->IsTable() ) { + err_msg = aname + " is not valid for global variables except for tables"; + return false; + } + + const auto& atype = a->GetExpr()->GetType(); + + if ( type->Tag() != TYPE_TABLE || (type->IsSet() && ! in_record) ) { + if ( same_type(atype, type) ) + // Ok. + return true; + + // Record defaults may be promotable. + if ( (type->Tag() == TYPE_RECORD && atype->Tag() == TYPE_RECORD && + record_promotion_compatible(atype->AsRecordType(), type->AsRecordType())) ) + // Ok. + return true; + + if ( type->Tag() == TYPE_TABLE && type->AsTableType()->IsUnspecifiedTable() ) + // Ok. + return true; + + auto e = check_and_promote_expr(a->GetExpr(), type); + + if ( e ) { + a->SetAttrExpr(std::move(e)); + // Ok. + return true; + } + + a->GetExpr()->Error(util::fmt("%s value has inconsistent type", aname.c_str()), type.get()); + return false; + } + + TableType* tt = type->AsTableType(); + const auto& ytype = tt->Yield(); + + if ( ! in_record ) { // &default applies to the type itself. + if ( same_type(atype, ytype) ) + return true; + + // It can still be a default function. + if ( atype->Tag() == TYPE_FUNC ) { + FuncType* f = atype->AsFuncType(); + if ( ! f->CheckArgs(tt->GetIndexTypes()) || ! same_type(f->Yield(), ytype) ) { + err_msg = aname + " function type clash"; + return false; + } + + // Ok. + return true; + } + + // Table defaults may be promotable. + if ( (ytype->Tag() == TYPE_RECORD && atype->Tag() == TYPE_RECORD && + record_promotion_compatible(atype->AsRecordType(), ytype->AsRecordType())) ) + // Ok. + return true; + + auto e = check_and_promote_expr(a->GetExpr(), ytype); + + if ( e ) { + a->SetAttrExpr(std::move(e)); + // Ok. + return true; + } + + err_msg = aname + " value has inconsistent type"; + return false; + } + + // &default applies to record field. + + if ( same_type(atype, type) ) + return true; + + if ( (atype->Tag() == TYPE_TABLE && atype->AsTableType()->IsUnspecifiedTable()) ) { + auto e = check_and_promote_expr(a->GetExpr(), type); + + if ( e ) { + a->SetAttrExpr(std::move(e)); + return true; + } + } + + // Table defaults may be promotable. + if ( ytype && ytype->Tag() == TYPE_RECORD && atype->Tag() == TYPE_RECORD && + record_promotion_compatible(atype->AsRecordType(), ytype->AsRecordType()) ) + // Ok. + return true; + + err_msg = "&default value has inconsistent type"; + return false; +} + +detail::TraversalCode Attributes::Traverse(detail::TraversalCallback* cb) const { + auto tc = cb->PreAttrs(this); + HANDLE_TC_ATTRS_PRE(tc); + + for ( const auto& a : attrs ) { + tc = a->Traverse(cb); + HANDLE_TC_ATTRS_PRE(tc); + } + + tc = cb->PostAttrs(this); + HANDLE_TC_ATTRS_POST(tc); +} + +} // namespace zeek::detail diff --git a/src/Attr.h b/src/Attr.h index 7d9f6a1a7f..546afb7d88 100644 --- a/src/Attr.h +++ b/src/Attr.h @@ -14,137 +14,131 @@ // modify expressions or supply metadata on types, and the kind that // are extra metadata on every variable instance. -namespace zeek - { +namespace zeek { class Type; using TypePtr = IntrusivePtr; -namespace detail - { +namespace detail { class Expr; using ExprPtr = IntrusivePtr; -enum AttrTag - { - ATTR_OPTIONAL, - ATTR_DEFAULT, - ATTR_DEFAULT_INSERT, // insert default value on failed lookups - ATTR_REDEF, - ATTR_ADD_FUNC, - ATTR_DEL_FUNC, - ATTR_EXPIRE_FUNC, - ATTR_EXPIRE_READ, - ATTR_EXPIRE_WRITE, - ATTR_EXPIRE_CREATE, - ATTR_RAW_OUTPUT, - ATTR_PRIORITY, - ATTR_GROUP, - ATTR_LOG, - ATTR_ERROR_HANDLER, - ATTR_TYPE_COLUMN, // for input framework - ATTR_TRACKED, // hidden attribute, tracked by NotifierRegistry - ATTR_ON_CHANGE, // for table change tracking - ATTR_BROKER_STORE, // for Broker store backed tables - ATTR_BROKER_STORE_ALLOW_COMPLEX, // for Broker store backed tables - ATTR_BACKEND, // for Broker store backed tables - ATTR_DEPRECATED, - ATTR_IS_ASSIGNED, // to suppress usage warnings - ATTR_IS_USED, // to suppress usage warnings - ATTR_ORDERED, // used to store tables in ordered mode - NUM_ATTRS // this item should always be last - }; +enum AttrTag { + ATTR_OPTIONAL, + ATTR_DEFAULT, + ATTR_DEFAULT_INSERT, // insert default value on failed lookups + ATTR_REDEF, + ATTR_ADD_FUNC, + ATTR_DEL_FUNC, + ATTR_EXPIRE_FUNC, + ATTR_EXPIRE_READ, + ATTR_EXPIRE_WRITE, + ATTR_EXPIRE_CREATE, + ATTR_RAW_OUTPUT, + ATTR_PRIORITY, + ATTR_GROUP, + ATTR_LOG, + ATTR_ERROR_HANDLER, + ATTR_TYPE_COLUMN, // for input framework + ATTR_TRACKED, // hidden attribute, tracked by NotifierRegistry + ATTR_ON_CHANGE, // for table change tracking + ATTR_BROKER_STORE, // for Broker store backed tables + ATTR_BROKER_STORE_ALLOW_COMPLEX, // for Broker store backed tables + ATTR_BACKEND, // for Broker store backed tables + ATTR_DEPRECATED, + ATTR_IS_ASSIGNED, // to suppress usage warnings + ATTR_IS_USED, // to suppress usage warnings + ATTR_ORDERED, // used to store tables in ordered mode + NUM_ATTRS // this item should always be last +}; class Attr; using AttrPtr = IntrusivePtr; class Attributes; using AttributesPtr = IntrusivePtr; -class Attr final : public Obj - { +class Attr final : public Obj { public: - static inline const AttrPtr nil; + static inline const AttrPtr nil; - Attr(AttrTag t, ExprPtr e); - explicit Attr(AttrTag t); + Attr(AttrTag t, ExprPtr e); + explicit Attr(AttrTag t); - ~Attr() override = default; + ~Attr() override = default; - AttrTag Tag() const { return tag; } + AttrTag Tag() const { return tag; } - const ExprPtr& GetExpr() const { return expr; } + const ExprPtr& GetExpr() const { return expr; } - void SetAttrExpr(ExprPtr e); + void SetAttrExpr(ExprPtr e); - void Describe(ODesc* d) const override; - void DescribeReST(ODesc* d, bool shorten = false) const; + void Describe(ODesc* d) const override; + void DescribeReST(ODesc* d, bool shorten = false) const; - /** - * Returns the deprecation string associated with a &deprecated attribute - * or an empty string if this is not such an attribute. - */ - std::string DeprecationMessage() const; + /** + * Returns the deprecation string associated with a &deprecated attribute + * or an empty string if this is not such an attribute. + */ + std::string DeprecationMessage() const; - bool operator==(const Attr& other) const - { - if ( tag != other.tag ) - return false; + bool operator==(const Attr& other) const { + if ( tag != other.tag ) + return false; - if ( expr || other.expr ) - // Too hard to check for equivalency, since one - // might be expressed/compiled differently than - // the other, so assume they're compatible, as - // long as both are present. - return expr && other.expr; + if ( expr || other.expr ) + // Too hard to check for equivalency, since one + // might be expressed/compiled differently than + // the other, so assume they're compatible, as + // long as both are present. + return expr && other.expr; - return true; - } + return true; + } - detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; protected: - void AddTag(ODesc* d) const; + void AddTag(ODesc* d) const; - AttrTag tag; - ExprPtr expr; - }; + AttrTag tag; + ExprPtr expr; +}; // Manages a collection of attributes. -class Attributes final : public Obj - { +class Attributes final : public Obj { public: - Attributes(std::vector a, TypePtr t, bool in_record, bool is_global); - Attributes(TypePtr t, bool in_record, bool is_global); + Attributes(std::vector a, TypePtr t, bool in_record, bool is_global); + Attributes(TypePtr t, bool in_record, bool is_global); - ~Attributes() override = default; + ~Attributes() override = default; - void AddAttr(AttrPtr a, bool is_redef = false); + void AddAttr(AttrPtr a, bool is_redef = false); - void AddAttrs(const AttributesPtr& a, bool is_redef = false); + void AddAttrs(const AttributesPtr& a, bool is_redef = false); - const AttrPtr& Find(AttrTag t) const; + const AttrPtr& Find(AttrTag t) const; - void RemoveAttr(AttrTag t); + void RemoveAttr(AttrTag t); - void Describe(ODesc* d) const override; - void DescribeReST(ODesc* d, bool shorten = false) const; + void Describe(ODesc* d) const override; + void DescribeReST(ODesc* d, bool shorten = false) const; - const std::vector& GetAttrs() const { return attrs; } + const std::vector& GetAttrs() const { return attrs; } - bool operator==(const Attributes& other) const; + bool operator==(const Attributes& other) const; - detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; + detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; protected: - void CheckAttr(Attr* attr); + void CheckAttr(Attr* attr); - TypePtr type; - std::vector attrs; + TypePtr type; + std::vector attrs; - bool in_record; - bool global_var; - }; + bool in_record; + bool global_var; +}; // Checks whether default attribute "a" is compatible with the given type. // "global_var" specifies whether the attribute is being associated with @@ -154,8 +148,7 @@ protected: // Returns true on compatibility (which might include modifying "a"), false // on an error. If an error message hasn't been directly generated, then // it will be returned in err_msg. -extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, - std::string& err_msg); +extern bool check_default_attr(Attr* a, const TypePtr& type, bool global_var, bool in_record, std::string& err_msg); - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/Base64.cc b/src/Base64.cc index b88c7dff60..cd6b509e3a 100644 --- a/src/Base64.cc +++ b/src/Base64.cc @@ -8,277 +8,251 @@ #include "zeek/Reporter.h" #include "zeek/ZeekString.h" -namespace zeek::detail - { +namespace zeek::detail { int Base64Converter::default_base64_table[256]; const std::string Base64Converter::default_alphabet = - "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; -void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) - { - int blen; - char* buf; +void Base64Converter::Encode(int len, const unsigned char* data, int* pblen, char** pbuf) { + int blen; + char* buf; - if ( ! pbuf ) - reporter->InternalError("nil pointer to encoding result buffer"); + if ( ! pbuf ) + reporter->InternalError("nil pointer to encoding result buffer"); - if ( *pbuf && (*pblen % 4 != 0) ) - reporter->InternalError("Base64 encode buffer not a multiple of 4"); + if ( *pbuf && (*pblen % 4 != 0) ) + reporter->InternalError("Base64 encode buffer not a multiple of 4"); - if ( *pbuf ) - { - buf = *pbuf; - blen = *pblen; - } - else - { - blen = (int)(4 * ceil((double)len / 3)); - *pbuf = buf = new char[blen]; - *pblen = blen; - } + if ( *pbuf ) { + buf = *pbuf; + blen = *pblen; + } + else { + blen = (int)(4 * ceil((double)len / 3)); + *pbuf = buf = new char[blen]; + *pblen = blen; + } - for ( int i = 0, j = 0; (i < len) && (j < blen); ) - { - uint32_t bit32 = data[i++] << 16; - bit32 += (i++ < len ? data[i - 1] : 0) << 8; - bit32 += i++ < len ? data[i - 1] : 0; + for ( int i = 0, j = 0; (i < len) && (j < blen); ) { + uint32_t bit32 = data[i++] << 16; + bit32 += (i++ < len ? data[i - 1] : 0) << 8; + bit32 += i++ < len ? data[i - 1] : 0; - buf[j++] = alphabet[(bit32 >> 18) & 0x3f]; - buf[j++] = alphabet[(bit32 >> 12) & 0x3f]; - buf[j++] = (i == (len + 2)) ? '=' : alphabet[(bit32 >> 6) & 0x3f]; - buf[j++] = (i >= (len + 1)) ? '=' : alphabet[bit32 & 0x3f]; - } - } + buf[j++] = alphabet[(bit32 >> 18) & 0x3f]; + buf[j++] = alphabet[(bit32 >> 12) & 0x3f]; + buf[j++] = (i == (len + 2)) ? '=' : alphabet[(bit32 >> 6) & 0x3f]; + buf[j++] = (i >= (len + 1)) ? '=' : alphabet[bit32 & 0x3f]; + } +} -int* Base64Converter::InitBase64Table(const std::string& alphabet) - { - assert(alphabet.size() == 64); +int* Base64Converter::InitBase64Table(const std::string& alphabet) { + assert(alphabet.size() == 64); - static bool default_table_initialized = false; + static bool default_table_initialized = false; - if ( alphabet == default_alphabet && default_table_initialized ) - return default_base64_table; + if ( alphabet == default_alphabet && default_table_initialized ) + return default_base64_table; - int* base64_table = nullptr; + int* base64_table = nullptr; - if ( alphabet == default_alphabet ) - { - base64_table = default_base64_table; - default_table_initialized = true; - } - else - base64_table = new int[256]; + if ( alphabet == default_alphabet ) { + base64_table = default_base64_table; + default_table_initialized = true; + } + else + base64_table = new int[256]; - int i; - for ( i = 0; i < 256; ++i ) - base64_table[i] = -1; + int i; + for ( i = 0; i < 256; ++i ) + base64_table[i] = -1; - for ( i = 0; i < 26; ++i ) - { - base64_table[int(alphabet[0 + i])] = i; - base64_table[int(alphabet[26 + i])] = i + 26; - } + for ( i = 0; i < 26; ++i ) { + base64_table[int(alphabet[0 + i])] = i; + base64_table[int(alphabet[26 + i])] = i + 26; + } - for ( i = 0; i < 10; ++i ) - base64_table[int(alphabet[52 + i])] = i + 52; + for ( i = 0; i < 10; ++i ) + base64_table[int(alphabet[52 + i])] = i + 52; - // Casts to avoid compiler warnings. - base64_table[int(alphabet[62])] = 62; - base64_table[int(alphabet[63])] = 63; - base64_table[int('=')] = 0; + // Casts to avoid compiler warnings. + base64_table[int(alphabet[62])] = 62; + base64_table[int(alphabet[63])] = 63; + base64_table[int('=')] = 0; - return base64_table; - } + return base64_table; +} -Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) - { - if ( arg_alphabet.size() > 0 ) - { - assert(arg_alphabet.size() == 64); - alphabet = arg_alphabet; - } - else - { - alphabet = default_alphabet; - } +Base64Converter::Base64Converter(Connection* arg_conn, const std::string& arg_alphabet) { + if ( arg_alphabet.size() > 0 ) { + assert(arg_alphabet.size() == 64); + alphabet = arg_alphabet; + } + else { + alphabet = default_alphabet; + } - base64_table = nullptr; - base64_group_next = 0; - base64_padding = base64_after_padding = 0; - errored = 0; - conn = arg_conn; - } + base64_table = nullptr; + base64_group_next = 0; + base64_padding = base64_after_padding = 0; + errored = 0; + conn = arg_conn; +} -Base64Converter::~Base64Converter() - { - if ( base64_table != default_base64_table ) - delete[] base64_table; - } +Base64Converter::~Base64Converter() { + if ( base64_table != default_base64_table ) + delete[] base64_table; +} -int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) - { - int blen; - char* buf; +int Base64Converter::Decode(int len, const char* data, int* pblen, char** pbuf) { + int blen; + char* buf; - // Initialization of table on first_time call of Decode. - if ( ! base64_table ) - base64_table = InitBase64Table(alphabet); + // Initialization of table on first_time call of Decode. + if ( ! base64_table ) + base64_table = InitBase64Table(alphabet); - if ( ! pbuf ) - reporter->InternalError("nil pointer to decoding result buffer"); + if ( ! pbuf ) + reporter->InternalError("nil pointer to decoding result buffer"); - if ( *pbuf ) - { - buf = *pbuf; - blen = *pblen; - } - else - { - // Estimate the maximal number of 3-byte groups needed, - // plus 1 byte for the optional ending NUL. - blen = int((len + base64_group_next + 3) / 4) * 3 + 1; - *pbuf = buf = new char[blen]; - } + if ( *pbuf ) { + buf = *pbuf; + blen = *pblen; + } + else { + // Estimate the maximal number of 3-byte groups needed, + // plus 1 byte for the optional ending NUL. + blen = int((len + base64_group_next + 3) / 4) * 3 + 1; + *pbuf = buf = new char[blen]; + } - int dlen = 0; + int dlen = 0; - while ( true ) - { - if ( base64_group_next == 4 ) - { - // For every group of 4 6-bit numbers, - // write the decoded 3 bytes to the buffer. - if ( base64_after_padding ) - { - if ( ++errored == 1 ) - IllegalEncoding("extra base64 groups after '=' padding are ignored"); - base64_group_next = 0; - continue; - } + while ( true ) { + if ( base64_group_next == 4 ) { + // For every group of 4 6-bit numbers, + // write the decoded 3 bytes to the buffer. + if ( base64_after_padding ) { + if ( ++errored == 1 ) + IllegalEncoding("extra base64 groups after '=' padding are ignored"); + base64_group_next = 0; + continue; + } - int num_octets = 3 - base64_padding; + int num_octets = 3 - base64_padding; - if ( buf + num_octets > *pbuf + blen ) - break; + if ( buf + num_octets > *pbuf + blen ) + break; - uint32_t bit32 = ((base64_group[0] & 0x3f) << 18) | ((base64_group[1] & 0x3f) << 12) | - ((base64_group[2] & 0x3f) << 6) | ((base64_group[3] & 0x3f)); + uint32_t bit32 = ((base64_group[0] & 0x3f) << 18) | ((base64_group[1] & 0x3f) << 12) | + ((base64_group[2] & 0x3f) << 6) | ((base64_group[3] & 0x3f)); - if ( --num_octets >= 0 ) - *buf++ = char((bit32 >> 16) & 0xff); + if ( --num_octets >= 0 ) + *buf++ = char((bit32 >> 16) & 0xff); - if ( --num_octets >= 0 ) - *buf++ = char((bit32 >> 8) & 0xff); + if ( --num_octets >= 0 ) + *buf++ = char((bit32 >> 8) & 0xff); - if ( --num_octets >= 0 ) - *buf++ = char((bit32)&0xff); + if ( --num_octets >= 0 ) + *buf++ = char((bit32)&0xff); - if ( base64_padding > 0 ) - base64_after_padding = 1; + if ( base64_padding > 0 ) + base64_after_padding = 1; - base64_group_next = 0; - base64_padding = 0; - } + base64_group_next = 0; + base64_padding = 0; + } - if ( dlen >= len ) - break; + if ( dlen >= len ) + break; - unsigned char c = (unsigned char)data[dlen]; - if ( c == '=' ) - ++base64_padding; + unsigned char c = (unsigned char)data[dlen]; + if ( c == '=' ) + ++base64_padding; - int k = base64_table[c]; - if ( k >= 0 ) - base64_group[base64_group_next++] = k; - else - { - if ( ++errored == 1 ) - IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c)); - } + int k = base64_table[c]; + if ( k >= 0 ) + base64_group[base64_group_next++] = k; + else { + if ( ++errored == 1 ) + IllegalEncoding(util::fmt("character %d ignored by Base64 decoding", (int)c)); + } - ++dlen; - } + ++dlen; + } - *pblen = buf - *pbuf; - return dlen; - } + *pblen = buf - *pbuf; + return dlen; +} -int Base64Converter::Done(int* pblen, char** pbuf) - { - const char* padding = "==="; +int Base64Converter::Done(int* pblen, char** pbuf) { + const char* padding = "==="; - if ( base64_group_next != 0 ) - { - if ( base64_group_next < 4 ) - IllegalEncoding(util::fmt("incomplete base64 group, padding with %d bits of 0", - (4 - base64_group_next) * 6)); - Decode(4 - base64_group_next, padding, pblen, pbuf); - return -1; - } + if ( base64_group_next != 0 ) { + if ( base64_group_next < 4 ) + IllegalEncoding( + util::fmt("incomplete base64 group, padding with %d bits of 0", (4 - base64_group_next) * 6)); + Decode(4 - base64_group_next, padding, pblen, pbuf); + return -1; + } - if ( pblen ) - *pblen = 0; + if ( pblen ) + *pblen = 0; - return 0; - } + return 0; +} -void Base64Converter::IllegalEncoding(const char* msg) - { - // strncpy(error_msg, msg, sizeof(error_msg)); - if ( conn ) - conn->Weird("base64_illegal_encoding", msg); - else - reporter->Error("%s", msg); - } +void Base64Converter::IllegalEncoding(const char* msg) { + // strncpy(error_msg, msg, sizeof(error_msg)); + if ( conn ) + conn->Weird("base64_illegal_encoding", msg); + else + reporter->Error("%s", msg); +} -String* decode_base64(const String* s, const String* a, Connection* conn) - { - if ( a && a->Len() != 0 && a->Len() != 64 ) - { - reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString()); - return nullptr; - } +String* decode_base64(const String* s, const String* a, Connection* conn) { + if ( a && a->Len() != 0 && a->Len() != 64 ) { + reporter->Error("base64 decoding alphabet is not 64 characters: %s", a->CheckString()); + return nullptr; + } - int buf_len = int((s->Len() + 3) / 4) * 3 + 1; - int rlen2, rlen = buf_len; - char *rbuf2, *rbuf = new char[rlen]; + int buf_len = int((s->Len() + 3) / 4) * 3 + 1; + int rlen2, rlen = buf_len; + char *rbuf2, *rbuf = new char[rlen]; - Base64Converter dec(conn, a ? a->CheckString() : ""); - dec.Decode(s->Len(), (const char*)s->Bytes(), &rlen, &rbuf); + Base64Converter dec(conn, a ? a->CheckString() : ""); + dec.Decode(s->Len(), (const char*)s->Bytes(), &rlen, &rbuf); - if ( dec.Errored() ) - goto err; + if ( dec.Errored() ) + goto err; - rlen2 = buf_len - rlen; - rbuf2 = rbuf + rlen; - // Done() returns -1 if there isn't enough padding, but we just ignore - // it. - dec.Done(&rlen2, &rbuf2); - rlen += rlen2; + rlen2 = buf_len - rlen; + rbuf2 = rbuf + rlen; + // Done() returns -1 if there isn't enough padding, but we just ignore + // it. + dec.Done(&rlen2, &rbuf2); + rlen += rlen2; - rbuf[rlen] = '\0'; - return new String(true, (u_char*)rbuf, rlen); + rbuf[rlen] = '\0'; + return new String(true, (u_char*)rbuf, rlen); err: - delete[] rbuf; - return nullptr; - } + delete[] rbuf; + return nullptr; +} -String* encode_base64(const String* s, const String* a, Connection* conn) - { - if ( a && a->Len() != 0 && a->Len() != 64 ) - { - reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString()); - return nullptr; - } +String* encode_base64(const String* s, const String* a, Connection* conn) { + if ( a && a->Len() != 0 && a->Len() != 64 ) { + reporter->Error("base64 alphabet is not 64 characters: %s", a->CheckString()); + return nullptr; + } - char* outbuf = nullptr; - int outlen = 0; - Base64Converter enc(conn, a ? a->CheckString() : ""); - enc.Encode(s->Len(), (const unsigned char*)s->Bytes(), &outlen, &outbuf); + char* outbuf = nullptr; + int outlen = 0; + Base64Converter enc(conn, a ? a->CheckString() : ""); + enc.Encode(s->Len(), (const unsigned char*)s->Bytes(), &outlen, &outbuf); - return new String(true, (u_char*)outbuf, outlen); - } + return new String(true, (u_char*)outbuf, outlen); +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Base64.h b/src/Base64.h index 943cb52391..a6e258d05d 100644 --- a/src/Base64.h +++ b/src/Base64.h @@ -4,69 +4,66 @@ #include -namespace zeek - { +namespace zeek { class String; class Connection; -namespace detail - { +namespace detail { // Maybe we should have a base class for generic decoders? -class Base64Converter - { +class Base64Converter { public: - // is used for error reporting. If it is set to zero (as, - // e.g., done by the built-in functions decode_base64() and - // encode_base64()), encoding-errors will go to Reporter instead of - // Weird. Usage errors go to Reporter in any case. Empty alphabet - // indicates the default base64 alphabet. - explicit Base64Converter(Connection* conn, const std::string& alphabet = ""); - ~Base64Converter(); + // is used for error reporting. If it is set to zero (as, + // e.g., done by the built-in functions decode_base64() and + // encode_base64()), encoding-errors will go to Reporter instead of + // Weird. Usage errors go to Reporter in any case. Empty alphabet + // indicates the default base64 alphabet. + explicit Base64Converter(Connection* conn, const std::string& alphabet = ""); + ~Base64Converter(); - // A note on Decode(): - // - // The input is specified by and and the output - // buffer by and . If *buf is nil, a buffer of - // an appropriate size will be new'd and *buf will point - // to the buffer on return. *blen holds the length of - // decoded data on return. The function returns the number of - // input bytes processed, since the decoding will stop when there - // is not enough output buffer space. + // A note on Decode(): + // + // The input is specified by and and the output + // buffer by and . If *buf is nil, a buffer of + // an appropriate size will be new'd and *buf will point + // to the buffer on return. *blen holds the length of + // decoded data on return. The function returns the number of + // input bytes processed, since the decoding will stop when there + // is not enough output buffer space. - int Decode(int len, const char* data, int* blen, char** buf); - void Encode(int len, const unsigned char* data, int* blen, char** buf); + int Decode(int len, const char* data, int* blen, char** buf); + void Encode(int len, const unsigned char* data, int* blen, char** buf); - int Done(int* pblen, char** pbuf); - bool HasData() const { return base64_group_next != 0; } + int Done(int* pblen, char** pbuf); + bool HasData() const { return base64_group_next != 0; } - // True if an error has occurred. - int Errored() const { return errored; } + // True if an error has occurred. + int Errored() const { return errored; } - const char* ErrorMsg() const { return error_msg; } - void IllegalEncoding(const char* msg); + const char* ErrorMsg() const { return error_msg; } + void IllegalEncoding(const char* msg); protected: - char error_msg[256]; + char error_msg[256]; protected: - static const std::string default_alphabet; - std::string alphabet; + static const std::string default_alphabet; + std::string alphabet; - static int* InitBase64Table(const std::string& alphabet); - static int default_base64_table[256]; - char base64_group[4]; - int base64_group_next; - int base64_padding; - int base64_after_padding; - int* base64_table; - int errored; // if true, we encountered an error - skip further processing - Connection* conn; - }; + static int* InitBase64Table(const std::string& alphabet); + static int default_base64_table[256]; + char base64_group[4]; + int base64_group_next; + int base64_padding; + int base64_after_padding; + int* base64_table; + int errored; // if true, we encountered an error - skip further processing + Connection* conn; +}; String* decode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr); String* encode_base64(const String* s, const String* a = nullptr, Connection* conn = nullptr); - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/BifReturnVal.cc b/src/BifReturnVal.cc index 48c57bd004..0a2e3593e4 100644 --- a/src/BifReturnVal.cc +++ b/src/BifReturnVal.cc @@ -4,9 +4,8 @@ #include "zeek/Val.h" -namespace zeek::detail - { +namespace zeek::detail { -BifReturnVal::BifReturnVal(std::nullptr_t) noexcept { } +BifReturnVal::BifReturnVal(std::nullptr_t) noexcept {} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/BifReturnVal.h b/src/BifReturnVal.h index 70fd901ba2..79445b741f 100644 --- a/src/BifReturnVal.h +++ b/src/BifReturnVal.h @@ -6,31 +6,27 @@ #include "zeek/IntrusivePtr.h" -namespace zeek - { +namespace zeek { class Val; using ValPtr = IntrusivePtr; -namespace detail - { +namespace detail { /** * A simple wrapper class to use for the return value of BIFs so that * they may return either a Val* or IntrusivePtr (the former could * potentially be deprecated). */ -class BifReturnVal - { +class BifReturnVal { public: - template BifReturnVal(IntrusivePtr v) noexcept : rval(AdoptRef{}, v.release()) - { - } + template + BifReturnVal(IntrusivePtr v) noexcept : rval(AdoptRef{}, v.release()) {} - BifReturnVal(std::nullptr_t) noexcept; + BifReturnVal(std::nullptr_t) noexcept; - ValPtr rval; - }; + ValPtr rval; +}; - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/CCL.cc b/src/CCL.cc index e396c5af55..70b05bf9a1 100644 --- a/src/CCL.cc +++ b/src/CCL.cc @@ -9,43 +9,33 @@ #include "zeek/DFA.h" #include "zeek/RE.h" -namespace zeek::detail - { +namespace zeek::detail { -CCL::CCL() - { - syms = new int_list; - index = -(rem->InsertCCL(this) + 1); - negated = 0; - } +CCL::CCL() { + syms = new int_list; + index = -(rem->InsertCCL(this) + 1); + negated = 0; +} -CCL::~CCL() - { - delete syms; - } +CCL::~CCL() { delete syms; } -void CCL::Negate() - { - negated = 1; - Add(SYM_BOL); - Add(SYM_EOL); - } +void CCL::Negate() { + negated = 1; + Add(SYM_BOL); + Add(SYM_EOL); +} -void CCL::Add(int sym) - { - auto sym_p = static_cast(sym); +void CCL::Add(int sym) { + auto sym_p = static_cast(sym); - // Check to see if the character is already in the ccl. - for ( auto sym_entry : *syms ) - if ( sym_entry == sym_p ) - return; + // Check to see if the character is already in the ccl. + for ( auto sym_entry : *syms ) + if ( sym_entry == sym_p ) + return; - syms->push_back(sym_p); - } + syms->push_back(sym_p); +} -void CCL::Sort() - { - std::sort(syms->begin(), syms->end()); - } +void CCL::Sort() { std::sort(syms->begin(), syms->end()); } - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/CCL.h b/src/CCL.h index 5b5688a333..7a9183bc25 100644 --- a/src/CCL.h +++ b/src/CCL.h @@ -5,36 +5,33 @@ #include #include -namespace zeek::detail - { +namespace zeek::detail { using int_list = std::vector; -class CCL - { +class CCL { public: - CCL(); - ~CCL(); + CCL(); + ~CCL(); - void Add(int sym); - void Negate(); - bool IsNegated() { return negated != 0; } - int Index() { return index; } + void Add(int sym); + void Negate(); + bool IsNegated() { return negated != 0; } + int Index() { return index; } - void Sort(); + void Sort(); - int_list* Syms() { return syms; } + int_list* Syms() { return syms; } - void ReplaceSyms(int_list* new_syms) - { - delete syms; - syms = new_syms; - } + void ReplaceSyms(int_list* new_syms) { + delete syms; + syms = new_syms; + } protected: - int_list* syms; - int negated; - int index; - }; + int_list* syms; + int negated; + int index; +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/CompHash.cc b/src/CompHash.cc index aa0a9846f2..aef744ee9d 100644 --- a/src/CompHash.cc +++ b/src/CompHash.cc @@ -15,941 +15,778 @@ #include "zeek/Val.h" #include "zeek/ZeekString.h" -namespace zeek::detail - { +namespace zeek::detail { // A comparison callable to assist with consistent iteration order over tables // during reservation & writes. -struct HashKeyComparer - { - bool operator()(const std::unique_ptr& a, const std::unique_ptr& b) const - { - if ( a->Hash() != b->Hash() ) - return a->Hash() < b->Hash(); - if ( a->Size() != b->Size() ) - return a->Size() < b->Size(); - return memcmp(a->Key(), b->Key(), a->Size()) < 0; - } - }; +struct HashKeyComparer { + bool operator()(const std::unique_ptr& a, const std::unique_ptr& b) const { + if ( a->Hash() != b->Hash() ) + return a->Hash() < b->Hash(); + if ( a->Size() != b->Size() ) + return a->Size() < b->Size(); + return memcmp(a->Key(), b->Key(), a->Size()) < 0; + } +}; using HashkeyMap = std::map, ListValPtr, HashKeyComparer>; using HashkeyMapPtr = std::unique_ptr; // Helper that produces a table from HashKeys to the ListVal indexes into the // table, that we can iterate over in sorted-Hashkey order. -const HashkeyMapPtr ordered_hashkeys(const TableVal* tv) - { - auto res = std::make_unique(); - auto tbl = tv->AsTable(); - auto idx = 0; - - for ( const auto& entry : *tbl ) - { - auto k = entry.GetHashKey(); - // Potential optimization: we could do without the following if - // the caller uses k directly to determine key length & - // content. But: the way k got serialized might differ somewhat - // from how we'll end up doing it (e.g. singleton vs - // non-singleton), and looking up a table value with the hashkey - // is tricky in case of subnets (consider the special-casing in - // TableVal::Find()). - auto lv = tv->RecreateIndex(*k); - res->insert_or_assign(std::move(k), lv); - } - - return res; - } - -CompositeHash::CompositeHash(TypeListPtr composite_type) : type(std::move(composite_type)) - { - if ( type->GetTypes().size() == 1 ) - is_singleton = true; - } - -std::unique_ptr CompositeHash::MakeHashKey(const Val& argv, bool type_check) const - { - auto res = std::make_unique(); - const auto& tl = type->GetTypes(); - - if ( is_singleton ) - { - const Val* v = &argv; - - // This is the "singleton" case -- actually just a single value - // that may come bundled in a list. If so, unwrap it. - if ( v->GetType()->Tag() == TYPE_LIST ) - { - auto lv = v->AsListVal(); - - if ( type_check && lv->Length() != 1 ) - return nullptr; - - v = lv->Idx(0).get(); - } - - if ( SingleValHash(*res, v, tl[0].get(), type_check, false, true) ) - return res; - - return nullptr; - } - - if ( type_check && argv.GetType()->Tag() != TYPE_LIST ) - return nullptr; - - if ( ! ReserveKeySize(*res, &argv, type_check, false) ) - return nullptr; - - // Size computation has done requested type-checking, no further need - type_check = false; - - // The size computation resulted in a requested buffer size; allocate it. - res->Allocate(); - - for ( auto i = 0u; i < tl.size(); ++i ) - { - if ( ! SingleValHash(*res, argv.AsListVal()->Idx(i).get(), tl[i].get(), type_check, false, - false) ) - return nullptr; - } - - return res; - } - -ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const - { - auto l = make_intrusive(TYPE_ANY); - const auto& tl = type->GetTypes(); - - hk.ResetRead(); - - for ( const auto& type : tl ) - { - ValPtr v; - - if ( ! RecoverOneVal(hk, type.get(), &v, false, is_singleton) ) - reporter->InternalError("value recovery failure in CompositeHash::RecoverVals"); - - ASSERT(v); - l->Append(std::move(v)); - } - - return l; - } - -bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool optional, - bool singleton) const - { - TypeTag tag = t->Tag(); - InternalTypeTag it = t->InternalType(); - - if ( optional ) - { - bool opt; - hk.Read("optional", opt); - - if ( ! opt ) - { - *pval = nullptr; - return true; - } - } - - switch ( it ) - { - case TYPE_INTERNAL_INT: - { - zeek_int_t i; - hk.Read("int", i); - - if ( tag == TYPE_ENUM ) - *pval = t->AsEnumType()->GetEnumVal(i); - else if ( tag == TYPE_BOOL ) - *pval = val_mgr->Bool(i); - else if ( tag == TYPE_INT ) - *pval = val_mgr->Int(i); - else - { - reporter->InternalError( - "bad internal unsigned int in CompositeHash::RecoverOneVal()"); - *pval = nullptr; - return false; - } - } - break; - - case TYPE_INTERNAL_UNSIGNED: - { - zeek_uint_t u; - hk.Read("unsigned", u); - - switch ( tag ) - { - case TYPE_COUNT: - *pval = val_mgr->Count(u); - break; - - case TYPE_PORT: - *pval = val_mgr->Port(u); - break; - - default: - reporter->InternalError( - "bad internal unsigned int in CompositeHash::RecoverOneVal()"); - *pval = nullptr; - return false; - } - } - break; - - case TYPE_INTERNAL_DOUBLE: - { - double d; - hk.Read("double", d); - - if ( tag == TYPE_INTERVAL ) - *pval = make_intrusive(d, 1.0); - else if ( tag == TYPE_TIME ) - *pval = make_intrusive(d); - else - *pval = make_intrusive(d); - } - break; - - case TYPE_INTERNAL_ADDR: - { - hk.AlignRead(sizeof(uint32_t)); - hk.EnsureReadSpace(sizeof(uint32_t) * 4); - IPAddr addr(IPv6, static_cast(hk.KeyAtRead()), IPAddr::Network); - hk.SkipRead("addr", sizeof(uint32_t) * 4); - - switch ( tag ) - { - case TYPE_ADDR: - *pval = make_intrusive(addr); - break; - - default: - reporter->InternalError( - "bad internal address in CompositeHash::RecoverOneVal()"); - *pval = nullptr; - return false; - } - } - break; - - case TYPE_INTERNAL_SUBNET: - { - hk.AlignRead(sizeof(uint32_t)); - hk.EnsureReadSpace(sizeof(uint32_t) * 4); - IPAddr addr(IPv6, static_cast(hk.KeyAtRead()), IPAddr::Network); - hk.SkipRead("subnet", sizeof(uint32_t) * 4); - - uint32_t width; - hk.Read("subnet-width", width); - *pval = make_intrusive(addr, width); - } - break; - - case TYPE_INTERNAL_VOID: - case TYPE_INTERNAL_OTHER: - { - switch ( t->Tag() ) - { - case TYPE_FUNC: - { - uint32_t id; - hk.Read("func", id); - - ASSERT(func_id_to_func != nullptr); - - if ( id >= func_id_to_func->size() ) - reporter->InternalError("failed to look up unique function id %" PRIu32 - " in CompositeHash::RecoverOneVal()", - id); - - const auto& f = func_id_to_func->at(id); - - *pval = make_intrusive(f); - const auto& pvt = (*pval)->GetType(); - - if ( ! pvt ) - reporter->InternalError( - "bad aggregate Val in CompositeHash::RecoverOneVal()"); - - else if ( t->Tag() != TYPE_FUNC && ! same_type(pvt, t) ) - // ### Maybe fix later, but may be fundamentally un-checkable --US - { - reporter->InternalError( - "inconsistent aggregate Val in CompositeHash::RecoverOneVal()"); - *pval = nullptr; - return false; - } - - // ### A crude approximation for now. - else if ( t->Tag() == TYPE_FUNC && pvt->Tag() != TYPE_FUNC ) - { - reporter->InternalError( - "inconsistent aggregate Val in CompositeHash::RecoverOneVal()"); - *pval = nullptr; - return false; - } - } - break; - - case TYPE_PATTERN: - { - const char* texts[2] = {nullptr, nullptr}; - uint64_t lens[2] = {0, 0}; - - if ( ! singleton ) - { - hk.Read("pattern-len1", lens[0]); - hk.Read("pattern-len2", lens[1]); - } - - texts[0] = static_cast(hk.KeyAtRead()); - hk.SkipRead("pattern-string1", strlen(texts[0]) + 1); - texts[1] = static_cast(hk.KeyAtRead()); - hk.SkipRead("pattern-string2", strlen(texts[1]) + 1); - - RE_Matcher* re = new RE_Matcher(texts[0], texts[1]); - - if ( ! re->Compile() ) - reporter->InternalError("failed compiling table/set key pattern: %s", - re->PatternText()); - - *pval = make_intrusive(re); - } - break; - - case TYPE_RECORD: - { - auto rt = t->AsRecordType(); - int num_fields = rt->NumFields(); - - std::vector values; - int i; - for ( i = 0; i < num_fields; ++i ) - { - ValPtr v; - Attributes* a = rt->FieldDecl(i)->attrs.get(); - bool is_optional = (a && a->Find(ATTR_OPTIONAL)); - - if ( ! RecoverOneVal(hk, rt->GetFieldType(i).get(), &v, is_optional, - false) ) - { - *pval = nullptr; - return false; - } - - // An earlier call to reporter->InternalError would have called - // abort() and broken the call tree that clang-tidy is relying on to - // get the error described. - // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Branch) - if ( ! (v || is_optional) ) - { - reporter->InternalError( - "didn't recover expected number of fields from HashKey"); - *pval = nullptr; - return false; - } - - values.emplace_back(std::move(v)); - } - - ASSERT(int(values.size()) == num_fields); - - auto rv = make_intrusive(IntrusivePtr{NewRef{}, rt}, - false /* init_fields */); - - for ( int i = 0; i < num_fields; ++i ) - rv->AppendField(std::move(values[i]), rt->GetFieldType(i)); - - *pval = std::move(rv); - } - break; - - case TYPE_TABLE: - { - int n; - hk.Read("table-size", n); - auto tt = t->AsTableType(); - auto tv = make_intrusive(IntrusivePtr{NewRef{}, tt}); - - for ( int i = 0; i < n; ++i ) - { - ValPtr key; - if ( ! RecoverOneVal(hk, tt->GetIndices().get(), &key, false, false) ) - { - *pval = nullptr; - return false; - } - - if ( t->IsSet() ) - tv->Assign(std::move(key), nullptr); - else - { - ValPtr value; - if ( ! RecoverOneVal(hk, tt->Yield().get(), &value, false, false) ) - { - *pval = nullptr; - return false; - } - tv->Assign(std::move(key), std::move(value)); - } - } - - *pval = std::move(tv); - } - break; - - case TYPE_VECTOR: - { - unsigned int n; - hk.Read("vector-size", n); - auto vt = t->AsVectorType(); - auto vv = make_intrusive(IntrusivePtr{NewRef{}, vt}); - - for ( unsigned int i = 0; i < n; ++i ) - { - unsigned int index; - hk.Read("vector-idx", index); - bool have_val; - hk.Read("vector-idx-present", have_val); - ValPtr value; - - if ( have_val && - ! RecoverOneVal(hk, vt->Yield().get(), &value, false, false) ) - { - *pval = nullptr; - return false; - } - - vv->Assign(index, std::move(value)); - } - - *pval = std::move(vv); - } - break; - - case TYPE_LIST: - { - int n; - hk.Read("list-size", n); - auto tl = t->AsTypeList(); - auto lv = make_intrusive(TYPE_ANY); - - for ( int i = 0; i < n; ++i ) - { - ValPtr v; - Type* it = tl->GetTypes()[i].get(); - if ( ! RecoverOneVal(hk, it, &v, false, false) ) - return false; - lv->Append(std::move(v)); - } - - *pval = std::move(lv); - } - break; - - default: - { - reporter->InternalError("bad index type in CompositeHash::RecoverOneVal"); - *pval = nullptr; - return false; - } - } - } - break; - - case TYPE_INTERNAL_STRING: - { - int n = hk.Size(); - - if ( ! singleton ) - { - hk.Read("string-len", n); - hk.EnsureReadSpace(n); - } - - *pval = make_intrusive(new String((const byte_vec)hk.KeyAtRead(), n, true)); - hk.SkipRead("string", n); - } - break; - - case TYPE_INTERNAL_ERROR: - break; - } - - return true; - } - -bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, - bool optional, bool singleton) const - { - InternalTypeTag t = bt->InternalType(); - - if ( type_check && v ) - { - InternalTypeTag vt = v->GetType()->InternalType(); - if ( vt != t ) - return false; - } - - if ( optional ) - { - // Add a marker saying whether the optional field is set. - hk.Write("optional", v != nullptr); - - if ( ! v ) - return true; - } - - // All of the rest of the code here depends on v not being null, since it needs - // to get values from it. - if ( ! v ) - return false; - - switch ( t ) - { - case TYPE_INTERNAL_INT: - hk.Write("int", v->AsInt()); - break; - - case TYPE_INTERNAL_UNSIGNED: - hk.Write("unsigned", v->AsCount()); - break; - - case TYPE_INTERNAL_ADDR: - if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) - return false; - - hk.AlignWrite(sizeof(uint32_t)); - hk.EnsureWriteSpace(sizeof(uint32_t) * 4); - v->AsAddr().CopyIPv6(static_cast(hk.KeyAtWrite())); - hk.SkipWrite("addr", sizeof(uint32_t) * 4); - break; - - case TYPE_INTERNAL_SUBNET: - if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) - return false; - - hk.AlignWrite(sizeof(uint32_t)); - hk.EnsureWriteSpace(sizeof(uint32_t) * 5); - v->AsSubNet().Prefix().CopyIPv6(static_cast(hk.KeyAtWrite())); - hk.SkipWrite("subnet", sizeof(uint32_t) * 4); - hk.Write("subnet-width", v->AsSubNet().Length()); - break; - - case TYPE_INTERNAL_DOUBLE: - hk.Write("double", v->InternalDouble()); - break; - - case TYPE_INTERNAL_VOID: - case TYPE_INTERNAL_OTHER: - { - switch ( v->GetType()->Tag() ) - { - case TYPE_FUNC: - { - auto f = v->AsFunc(); - - if ( ! func_to_func_id ) - const_cast(this)->BuildFuncMappings(); - - auto id_mapping = func_to_func_id->find(f); - uint32_t id; - - if ( id_mapping == func_to_func_id->end() ) - { - // We need the pointer to stick around - // for our lifetime, so we have to get - // a non-const version we can ref. - FuncPtr fptr = {NewRef{}, const_cast(f)}; - - id = func_id_to_func->size(); - func_id_to_func->push_back(std::move(fptr)); - func_to_func_id->insert_or_assign(f, id); - } - else - id = id_mapping->second; - - hk.Write("func", id); - } - break; - - case TYPE_PATTERN: - { - const char* texts[2] = {v->AsPattern()->PatternText(), - v->AsPattern()->AnywherePatternText()}; - uint64_t lens[2] = {strlen(texts[0]) + 1, strlen(texts[1]) + 1}; - - if ( ! singleton ) - { - hk.Write("pattern-len1", lens[0]); - hk.Write("pattern-len2", lens[1]); - } - else - { - hk.Reserve("pattern", lens[0] + lens[1]); - hk.Allocate(); - } - - hk.Write("pattern-string1", static_cast(texts[0]), lens[0]); - hk.Write("pattern-string2", static_cast(texts[1]), lens[1]); - break; - } - - case TYPE_RECORD: - { - auto rv = v->AsRecordVal(); - auto rt = bt->AsRecordType(); - int num_fields = rt->NumFields(); - - if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) - return false; - - for ( int i = 0; i < num_fields; ++i ) - { - auto rv_i = rv->GetField(i); - - Attributes* a = rt->FieldDecl(i)->attrs.get(); - bool optional_attr = (a && a->Find(ATTR_OPTIONAL)); - - if ( ! (rv_i || optional_attr) ) - return false; - - if ( ! SingleValHash(hk, rv_i.get(), rt->GetFieldType(i).get(), type_check, - optional_attr, false) ) - return false; - } - break; - } - - case TYPE_TABLE: - { - if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) - return false; - - auto tv = v->AsTableVal(); - auto hashkeys = ordered_hashkeys(tv); - - hk.Write("table-size", tv->Size()); - - for ( auto& kv : *hashkeys ) - { - auto key = kv.second; - - if ( ! SingleValHash(hk, key.get(), key->GetType().get(), type_check, false, - false) ) - return false; - - if ( ! v->GetType()->IsSet() ) - { - auto val = const_cast(tv)->FindOrDefault(key); - - if ( ! SingleValHash(hk, val.get(), val->GetType().get(), type_check, - false, false) ) - return false; - } - } - } - break; - - case TYPE_VECTOR: - { - if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) - return false; - - auto vv = v->AsVectorVal(); - auto vt = v->GetType()->AsVectorType(); - - hk.Write("vector-size", vv->Size()); - - for ( unsigned int i = 0; i < vv->Size(); ++i ) - { - auto val = vv->ValAt(i); - hk.Write("vector-idx", i); - hk.Write("vector-idx-present", val != nullptr); - - if ( val && ! SingleValHash(hk, val.get(), vt->Yield().get(), type_check, - false, false) ) - return false; - } - } - break; - - case TYPE_LIST: - { - if ( ! hk.IsAllocated() ) - { - if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false, - false) ) - return false; - - hk.Allocate(); - } - - auto lv = v->AsListVal(); - - hk.Write("list-size", lv->Length()); - - for ( int i = 0; i < lv->Length(); ++i ) - { - Val* entry_val = lv->Idx(i).get(); - if ( ! SingleValHash(hk, entry_val, entry_val->GetType().get(), type_check, - false, false) ) - return false; - } - } - break; - - default: - { - reporter->InternalError("bad index type in CompositeHash::SingleValHash"); - return false; - } - } - - break; // case TYPE_INTERNAL_VOID/OTHER - } - - case TYPE_INTERNAL_STRING: - { - if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) - return false; - - const auto sval = v->AsString(); - - if ( ! singleton ) - hk.Write("string-len", sval->Len()); - - hk.Write("string", sval->Bytes(), sval->Len()); - } - break; - - default: - return false; - } - - return true; - } - -bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const - { - if ( hk.IsAllocated() ) - return true; - - if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false, true) ) - return false; - - hk.Allocate(); - return true; - } - -bool CompositeHash::ReserveKeySize(HashKey& hk, const Val* v, bool type_check, - bool calc_static_size) const - { - const auto& tl = type->GetTypes(); - - for ( auto i = 0u; i < tl.size(); ++i ) - { - if ( ! ReserveSingleTypeKeySize(hk, tl[i].get(), v ? v->AsListVal()->Idx(i).get() : nullptr, - type_check, false, calc_static_size, is_singleton) ) - return false; - } - - return true; - } - -bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v, bool type_check, - bool optional, bool calc_static_size, - bool singleton) const - { - InternalTypeTag t = bt->InternalType(); - - if ( optional ) - { - hk.ReserveType("optional"); - if ( ! v ) - return true; - } - - if ( type_check && v ) - { - InternalTypeTag vt = v->GetType()->InternalType(); - if ( vt != t ) - return false; - } - - switch ( t ) - { - case TYPE_INTERNAL_INT: - hk.ReserveType("int"); - break; - - case TYPE_INTERNAL_UNSIGNED: - hk.ReserveType("unsigned"); - break; - - case TYPE_INTERNAL_ADDR: - hk.Reserve("addr", sizeof(uint32_t) * 4, sizeof(uint32_t)); - break; - - case TYPE_INTERNAL_SUBNET: - hk.Reserve("subnet", sizeof(uint32_t) * 5, sizeof(uint32_t)); - break; - - case TYPE_INTERNAL_DOUBLE: - hk.ReserveType("double"); - break; - - case TYPE_INTERNAL_VOID: - case TYPE_INTERNAL_OTHER: - { - switch ( bt->Tag() ) - { - case TYPE_FUNC: - { - hk.ReserveType("func"); - break; - } - - case TYPE_PATTERN: - { - if ( ! v ) - return (optional && ! calc_static_size); - - if ( ! singleton ) - { - hk.ReserveType("pattern-len1"); - hk.ReserveType("pattern-len2"); - } - - // +1 in the following to include null terminators - hk.Reserve("pattern-string1", strlen(v->AsPattern()->PatternText()) + 1, 0); - hk.Reserve("pattern-string1", strlen(v->AsPattern()->AnywherePatternText()) + 1, - 0); - break; - } - - case TYPE_RECORD: - { - if ( ! v ) - return (optional && ! calc_static_size); - - const RecordVal* rv = v->AsRecordVal(); - RecordType* rt = bt->AsRecordType(); - int num_fields = rt->NumFields(); - - for ( int i = 0; i < num_fields; ++i ) - { - Attributes* a = rt->FieldDecl(i)->attrs.get(); - bool optional_attr = (a && a->Find(ATTR_OPTIONAL)); - - auto rv_v = rv ? rv->GetField(i) : nullptr; - if ( ! ReserveSingleTypeKeySize(hk, rt->GetFieldType(i).get(), rv_v.get(), - type_check, optional_attr, calc_static_size, - false) ) - return false; - } - break; - } - - case TYPE_TABLE: - { - if ( ! v ) - return (optional && ! calc_static_size); - - auto tv = v->AsTableVal(); - auto hashkeys = ordered_hashkeys(tv); - - hk.ReserveType("table-size"); - - for ( auto& kv : *hashkeys ) - { - auto key = kv.second; - - if ( ! ReserveSingleTypeKeySize(hk, key->GetType().get(), key.get(), - type_check, false, calc_static_size, - false) ) - return false; - - if ( ! bt->IsSet() ) - { - auto val = const_cast(tv)->FindOrDefault(key); - if ( ! ReserveSingleTypeKeySize(hk, val->GetType().get(), val.get(), - type_check, false, calc_static_size, - false) ) - return false; - } - } - - break; - } - - case TYPE_VECTOR: - { - if ( ! v ) - return (optional && ! calc_static_size); - - hk.ReserveType("vector-size"); - VectorVal* vv = const_cast(v->AsVectorVal()); - for ( unsigned int i = 0; i < vv->Size(); ++i ) - { - auto val = vv->ValAt(i); - hk.ReserveType("vector-idx"); - hk.ReserveType("vector-idx-present"); - if ( val && ! ReserveSingleTypeKeySize( - hk, bt->AsVectorType()->Yield().get(), val.get(), - type_check, false, calc_static_size, false) ) - return false; - } - break; - } - - case TYPE_LIST: - { - if ( ! v ) - return (optional && ! calc_static_size); - - hk.ReserveType("list-size"); - ListVal* lv = const_cast(v->AsListVal()); - for ( int i = 0; i < lv->Length(); ++i ) - { - if ( ! ReserveSingleTypeKeySize(hk, lv->Idx(i)->GetType().get(), - lv->Idx(i).get(), type_check, false, - calc_static_size, false) ) - return false; - } - - break; - } - - default: - { - reporter->InternalError( - "bad index type in CompositeHash::ReserveSingleTypeKeySize"); - return false; - } - } - - break; // case TYPE_INTERNAL_VOID/OTHER - } - - case TYPE_INTERNAL_STRING: - if ( ! v ) - return (optional && ! calc_static_size); - if ( ! singleton ) - hk.ReserveType("string-len"); - hk.Reserve("string", v->AsString()->Len()); - break; - - case TYPE_INTERNAL_ERROR: - return false; - } - - return true; - } - - } // namespace zeek::detail +const HashkeyMapPtr ordered_hashkeys(const TableVal* tv) { + auto res = std::make_unique(); + auto tbl = tv->AsTable(); + auto idx = 0; + + for ( const auto& entry : *tbl ) { + auto k = entry.GetHashKey(); + // Potential optimization: we could do without the following if + // the caller uses k directly to determine key length & + // content. But: the way k got serialized might differ somewhat + // from how we'll end up doing it (e.g. singleton vs + // non-singleton), and looking up a table value with the hashkey + // is tricky in case of subnets (consider the special-casing in + // TableVal::Find()). + auto lv = tv->RecreateIndex(*k); + res->insert_or_assign(std::move(k), lv); + } + + return res; +} + +CompositeHash::CompositeHash(TypeListPtr composite_type) : type(std::move(composite_type)) { + if ( type->GetTypes().size() == 1 ) + is_singleton = true; +} + +std::unique_ptr CompositeHash::MakeHashKey(const Val& argv, bool type_check) const { + auto res = std::make_unique(); + const auto& tl = type->GetTypes(); + + if ( is_singleton ) { + const Val* v = &argv; + + // This is the "singleton" case -- actually just a single value + // that may come bundled in a list. If so, unwrap it. + if ( v->GetType()->Tag() == TYPE_LIST ) { + auto lv = v->AsListVal(); + + if ( type_check && lv->Length() != 1 ) + return nullptr; + + v = lv->Idx(0).get(); + } + + if ( SingleValHash(*res, v, tl[0].get(), type_check, false, true) ) + return res; + + return nullptr; + } + + if ( type_check && argv.GetType()->Tag() != TYPE_LIST ) + return nullptr; + + if ( ! ReserveKeySize(*res, &argv, type_check, false) ) + return nullptr; + + // Size computation has done requested type-checking, no further need + type_check = false; + + // The size computation resulted in a requested buffer size; allocate it. + res->Allocate(); + + for ( auto i = 0u; i < tl.size(); ++i ) { + if ( ! SingleValHash(*res, argv.AsListVal()->Idx(i).get(), tl[i].get(), type_check, false, false) ) + return nullptr; + } + + return res; +} + +ListValPtr CompositeHash::RecoverVals(const HashKey& hk) const { + auto l = make_intrusive(TYPE_ANY); + const auto& tl = type->GetTypes(); + + hk.ResetRead(); + + for ( const auto& type : tl ) { + ValPtr v; + + if ( ! RecoverOneVal(hk, type.get(), &v, false, is_singleton) ) + reporter->InternalError("value recovery failure in CompositeHash::RecoverVals"); + + ASSERT(v); + l->Append(std::move(v)); + } + + return l; +} + +bool CompositeHash::RecoverOneVal(const HashKey& hk, Type* t, ValPtr* pval, bool optional, bool singleton) const { + TypeTag tag = t->Tag(); + InternalTypeTag it = t->InternalType(); + + if ( optional ) { + bool opt; + hk.Read("optional", opt); + + if ( ! opt ) { + *pval = nullptr; + return true; + } + } + + switch ( it ) { + case TYPE_INTERNAL_INT: { + zeek_int_t i; + hk.Read("int", i); + + if ( tag == TYPE_ENUM ) + *pval = t->AsEnumType()->GetEnumVal(i); + else if ( tag == TYPE_BOOL ) + *pval = val_mgr->Bool(i); + else if ( tag == TYPE_INT ) + *pval = val_mgr->Int(i); + else { + reporter->InternalError("bad internal unsigned int in CompositeHash::RecoverOneVal()"); + *pval = nullptr; + return false; + } + } break; + + case TYPE_INTERNAL_UNSIGNED: { + zeek_uint_t u; + hk.Read("unsigned", u); + + switch ( tag ) { + case TYPE_COUNT: *pval = val_mgr->Count(u); break; + + case TYPE_PORT: *pval = val_mgr->Port(u); break; + + default: + reporter->InternalError("bad internal unsigned int in CompositeHash::RecoverOneVal()"); + *pval = nullptr; + return false; + } + } break; + + case TYPE_INTERNAL_DOUBLE: { + double d; + hk.Read("double", d); + + if ( tag == TYPE_INTERVAL ) + *pval = make_intrusive(d, 1.0); + else if ( tag == TYPE_TIME ) + *pval = make_intrusive(d); + else + *pval = make_intrusive(d); + } break; + + case TYPE_INTERNAL_ADDR: { + hk.AlignRead(sizeof(uint32_t)); + hk.EnsureReadSpace(sizeof(uint32_t) * 4); + IPAddr addr(IPv6, static_cast(hk.KeyAtRead()), IPAddr::Network); + hk.SkipRead("addr", sizeof(uint32_t) * 4); + + switch ( tag ) { + case TYPE_ADDR: *pval = make_intrusive(addr); break; + + default: + reporter->InternalError("bad internal address in CompositeHash::RecoverOneVal()"); + *pval = nullptr; + return false; + } + } break; + + case TYPE_INTERNAL_SUBNET: { + hk.AlignRead(sizeof(uint32_t)); + hk.EnsureReadSpace(sizeof(uint32_t) * 4); + IPAddr addr(IPv6, static_cast(hk.KeyAtRead()), IPAddr::Network); + hk.SkipRead("subnet", sizeof(uint32_t) * 4); + + uint32_t width; + hk.Read("subnet-width", width); + *pval = make_intrusive(addr, width); + } break; + + case TYPE_INTERNAL_VOID: + case TYPE_INTERNAL_OTHER: { + switch ( t->Tag() ) { + case TYPE_FUNC: { + uint32_t id; + hk.Read("func", id); + + ASSERT(func_id_to_func != nullptr); + + if ( id >= func_id_to_func->size() ) + reporter->InternalError("failed to look up unique function id %" PRIu32 + " in CompositeHash::RecoverOneVal()", + id); + + const auto& f = func_id_to_func->at(id); + + *pval = make_intrusive(f); + const auto& pvt = (*pval)->GetType(); + + if ( ! pvt ) + reporter->InternalError("bad aggregate Val in CompositeHash::RecoverOneVal()"); + + else if ( t->Tag() != TYPE_FUNC && ! same_type(pvt, t) ) + // ### Maybe fix later, but may be fundamentally un-checkable --US + { + reporter->InternalError("inconsistent aggregate Val in CompositeHash::RecoverOneVal()"); + *pval = nullptr; + return false; + } + + // ### A crude approximation for now. + else if ( t->Tag() == TYPE_FUNC && pvt->Tag() != TYPE_FUNC ) { + reporter->InternalError("inconsistent aggregate Val in CompositeHash::RecoverOneVal()"); + *pval = nullptr; + return false; + } + } break; + + case TYPE_PATTERN: { + const char* texts[2] = {nullptr, nullptr}; + uint64_t lens[2] = {0, 0}; + + if ( ! singleton ) { + hk.Read("pattern-len1", lens[0]); + hk.Read("pattern-len2", lens[1]); + } + + texts[0] = static_cast(hk.KeyAtRead()); + hk.SkipRead("pattern-string1", strlen(texts[0]) + 1); + texts[1] = static_cast(hk.KeyAtRead()); + hk.SkipRead("pattern-string2", strlen(texts[1]) + 1); + + RE_Matcher* re = new RE_Matcher(texts[0], texts[1]); + + if ( ! re->Compile() ) + reporter->InternalError("failed compiling table/set key pattern: %s", re->PatternText()); + + *pval = make_intrusive(re); + } break; + + case TYPE_RECORD: { + auto rt = t->AsRecordType(); + int num_fields = rt->NumFields(); + + std::vector values; + int i; + for ( i = 0; i < num_fields; ++i ) { + ValPtr v; + Attributes* a = rt->FieldDecl(i)->attrs.get(); + bool is_optional = (a && a->Find(ATTR_OPTIONAL)); + + if ( ! RecoverOneVal(hk, rt->GetFieldType(i).get(), &v, is_optional, false) ) { + *pval = nullptr; + return false; + } + + // An earlier call to reporter->InternalError would have called + // abort() and broken the call tree that clang-tidy is relying on to + // get the error described. + // NOLINTNEXTLINE(clang-analyzer-core.uninitialized.Branch) + if ( ! (v || is_optional) ) { + reporter->InternalError("didn't recover expected number of fields from HashKey"); + *pval = nullptr; + return false; + } + + values.emplace_back(std::move(v)); + } + + ASSERT(int(values.size()) == num_fields); + + auto rv = make_intrusive(IntrusivePtr{NewRef{}, rt}, false /* init_fields */); + + for ( int i = 0; i < num_fields; ++i ) + rv->AppendField(std::move(values[i]), rt->GetFieldType(i)); + + *pval = std::move(rv); + } break; + + case TYPE_TABLE: { + int n; + hk.Read("table-size", n); + auto tt = t->AsTableType(); + auto tv = make_intrusive(IntrusivePtr{NewRef{}, tt}); + + for ( int i = 0; i < n; ++i ) { + ValPtr key; + if ( ! RecoverOneVal(hk, tt->GetIndices().get(), &key, false, false) ) { + *pval = nullptr; + return false; + } + + if ( t->IsSet() ) + tv->Assign(std::move(key), nullptr); + else { + ValPtr value; + if ( ! RecoverOneVal(hk, tt->Yield().get(), &value, false, false) ) { + *pval = nullptr; + return false; + } + tv->Assign(std::move(key), std::move(value)); + } + } + + *pval = std::move(tv); + } break; + + case TYPE_VECTOR: { + unsigned int n; + hk.Read("vector-size", n); + auto vt = t->AsVectorType(); + auto vv = make_intrusive(IntrusivePtr{NewRef{}, vt}); + + for ( unsigned int i = 0; i < n; ++i ) { + unsigned int index; + hk.Read("vector-idx", index); + bool have_val; + hk.Read("vector-idx-present", have_val); + ValPtr value; + + if ( have_val && ! RecoverOneVal(hk, vt->Yield().get(), &value, false, false) ) { + *pval = nullptr; + return false; + } + + vv->Assign(index, std::move(value)); + } + + *pval = std::move(vv); + } break; + + case TYPE_LIST: { + int n; + hk.Read("list-size", n); + auto tl = t->AsTypeList(); + auto lv = make_intrusive(TYPE_ANY); + + for ( int i = 0; i < n; ++i ) { + ValPtr v; + Type* it = tl->GetTypes()[i].get(); + if ( ! RecoverOneVal(hk, it, &v, false, false) ) + return false; + lv->Append(std::move(v)); + } + + *pval = std::move(lv); + } break; + + default: { + reporter->InternalError("bad index type in CompositeHash::RecoverOneVal"); + *pval = nullptr; + return false; + } + } + } break; + + case TYPE_INTERNAL_STRING: { + int n = hk.Size(); + + if ( ! singleton ) { + hk.Read("string-len", n); + hk.EnsureReadSpace(n); + } + + *pval = make_intrusive(new String((const byte_vec)hk.KeyAtRead(), n, true)); + hk.SkipRead("string", n); + } break; + + case TYPE_INTERNAL_ERROR: break; + } + + return true; +} + +bool CompositeHash::SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, + bool singleton) const { + InternalTypeTag t = bt->InternalType(); + + if ( type_check && v ) { + InternalTypeTag vt = v->GetType()->InternalType(); + if ( vt != t ) + return false; + } + + if ( optional ) { + // Add a marker saying whether the optional field is set. + hk.Write("optional", v != nullptr); + + if ( ! v ) + return true; + } + + // All of the rest of the code here depends on v not being null, since it needs + // to get values from it. + if ( ! v ) + return false; + + switch ( t ) { + case TYPE_INTERNAL_INT: hk.Write("int", v->AsInt()); break; + + case TYPE_INTERNAL_UNSIGNED: hk.Write("unsigned", v->AsCount()); break; + + case TYPE_INTERNAL_ADDR: + if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) + return false; + + hk.AlignWrite(sizeof(uint32_t)); + hk.EnsureWriteSpace(sizeof(uint32_t) * 4); + v->AsAddr().CopyIPv6(static_cast(hk.KeyAtWrite())); + hk.SkipWrite("addr", sizeof(uint32_t) * 4); + break; + + case TYPE_INTERNAL_SUBNET: + if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) + return false; + + hk.AlignWrite(sizeof(uint32_t)); + hk.EnsureWriteSpace(sizeof(uint32_t) * 5); + v->AsSubNet().Prefix().CopyIPv6(static_cast(hk.KeyAtWrite())); + hk.SkipWrite("subnet", sizeof(uint32_t) * 4); + hk.Write("subnet-width", v->AsSubNet().Length()); + break; + + case TYPE_INTERNAL_DOUBLE: hk.Write("double", v->InternalDouble()); break; + + case TYPE_INTERNAL_VOID: + case TYPE_INTERNAL_OTHER: { + switch ( v->GetType()->Tag() ) { + case TYPE_FUNC: { + auto f = v->AsFunc(); + + if ( ! func_to_func_id ) + const_cast(this)->BuildFuncMappings(); + + auto id_mapping = func_to_func_id->find(f); + uint32_t id; + + if ( id_mapping == func_to_func_id->end() ) { + // We need the pointer to stick around + // for our lifetime, so we have to get + // a non-const version we can ref. + FuncPtr fptr = {NewRef{}, const_cast(f)}; + + id = func_id_to_func->size(); + func_id_to_func->push_back(std::move(fptr)); + func_to_func_id->insert_or_assign(f, id); + } + else + id = id_mapping->second; + + hk.Write("func", id); + } break; + + case TYPE_PATTERN: { + const char* texts[2] = {v->AsPattern()->PatternText(), v->AsPattern()->AnywherePatternText()}; + uint64_t lens[2] = {strlen(texts[0]) + 1, strlen(texts[1]) + 1}; + + if ( ! singleton ) { + hk.Write("pattern-len1", lens[0]); + hk.Write("pattern-len2", lens[1]); + } + else { + hk.Reserve("pattern", lens[0] + lens[1]); + hk.Allocate(); + } + + hk.Write("pattern-string1", static_cast(texts[0]), lens[0]); + hk.Write("pattern-string2", static_cast(texts[1]), lens[1]); + break; + } + + case TYPE_RECORD: { + auto rv = v->AsRecordVal(); + auto rt = bt->AsRecordType(); + int num_fields = rt->NumFields(); + + if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) + return false; + + for ( int i = 0; i < num_fields; ++i ) { + auto rv_i = rv->GetField(i); + + Attributes* a = rt->FieldDecl(i)->attrs.get(); + bool optional_attr = (a && a->Find(ATTR_OPTIONAL)); + + if ( ! (rv_i || optional_attr) ) + return false; + + if ( ! SingleValHash(hk, rv_i.get(), rt->GetFieldType(i).get(), type_check, optional_attr, + false) ) + return false; + } + break; + } + + case TYPE_TABLE: { + if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) + return false; + + auto tv = v->AsTableVal(); + auto hashkeys = ordered_hashkeys(tv); + + hk.Write("table-size", tv->Size()); + + for ( auto& kv : *hashkeys ) { + auto key = kv.second; + + if ( ! SingleValHash(hk, key.get(), key->GetType().get(), type_check, false, false) ) + return false; + + if ( ! v->GetType()->IsSet() ) { + auto val = const_cast(tv)->FindOrDefault(key); + + if ( ! SingleValHash(hk, val.get(), val->GetType().get(), type_check, false, false) ) + return false; + } + } + } break; + + case TYPE_VECTOR: { + if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) + return false; + + auto vv = v->AsVectorVal(); + auto vt = v->GetType()->AsVectorType(); + + hk.Write("vector-size", vv->Size()); + + for ( unsigned int i = 0; i < vv->Size(); ++i ) { + auto val = vv->ValAt(i); + hk.Write("vector-idx", i); + hk.Write("vector-idx-present", val != nullptr); + + if ( val && ! SingleValHash(hk, val.get(), vt->Yield().get(), type_check, false, false) ) + return false; + } + } break; + + case TYPE_LIST: { + if ( ! hk.IsAllocated() ) { + if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false, false) ) + return false; + + hk.Allocate(); + } + + auto lv = v->AsListVal(); + + hk.Write("list-size", lv->Length()); + + for ( int i = 0; i < lv->Length(); ++i ) { + Val* entry_val = lv->Idx(i).get(); + if ( ! SingleValHash(hk, entry_val, entry_val->GetType().get(), type_check, false, false) ) + return false; + } + } break; + + default: { + reporter->InternalError("bad index type in CompositeHash::SingleValHash"); + return false; + } + } + + break; // case TYPE_INTERNAL_VOID/OTHER + } + + case TYPE_INTERNAL_STRING: { + if ( ! EnsureTypeReserve(hk, v, bt, type_check) ) + return false; + + const auto sval = v->AsString(); + + if ( ! singleton ) + hk.Write("string-len", sval->Len()); + + hk.Write("string", sval->Bytes(), sval->Len()); + } break; + + default: return false; + } + + return true; +} + +bool CompositeHash::EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const { + if ( hk.IsAllocated() ) + return true; + + if ( ! ReserveSingleTypeKeySize(hk, bt, v, type_check, false, false, true) ) + return false; + + hk.Allocate(); + return true; +} + +bool CompositeHash::ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool calc_static_size) const { + const auto& tl = type->GetTypes(); + + for ( auto i = 0u; i < tl.size(); ++i ) { + if ( ! ReserveSingleTypeKeySize(hk, tl[i].get(), v ? v->AsListVal()->Idx(i).get() : nullptr, type_check, false, + calc_static_size, is_singleton) ) + return false; + } + + return true; +} + +bool CompositeHash::ReserveSingleTypeKeySize(HashKey& hk, Type* bt, const Val* v, bool type_check, bool optional, + bool calc_static_size, bool singleton) const { + InternalTypeTag t = bt->InternalType(); + + if ( optional ) { + hk.ReserveType("optional"); + if ( ! v ) + return true; + } + + if ( type_check && v ) { + InternalTypeTag vt = v->GetType()->InternalType(); + if ( vt != t ) + return false; + } + + switch ( t ) { + case TYPE_INTERNAL_INT: hk.ReserveType("int"); break; + + case TYPE_INTERNAL_UNSIGNED: hk.ReserveType("unsigned"); break; + + case TYPE_INTERNAL_ADDR: hk.Reserve("addr", sizeof(uint32_t) * 4, sizeof(uint32_t)); break; + + case TYPE_INTERNAL_SUBNET: hk.Reserve("subnet", sizeof(uint32_t) * 5, sizeof(uint32_t)); break; + + case TYPE_INTERNAL_DOUBLE: hk.ReserveType("double"); break; + + case TYPE_INTERNAL_VOID: + case TYPE_INTERNAL_OTHER: { + switch ( bt->Tag() ) { + case TYPE_FUNC: { + hk.ReserveType("func"); + break; + } + + case TYPE_PATTERN: { + if ( ! v ) + return (optional && ! calc_static_size); + + if ( ! singleton ) { + hk.ReserveType("pattern-len1"); + hk.ReserveType("pattern-len2"); + } + + // +1 in the following to include null terminators + hk.Reserve("pattern-string1", strlen(v->AsPattern()->PatternText()) + 1, 0); + hk.Reserve("pattern-string1", strlen(v->AsPattern()->AnywherePatternText()) + 1, 0); + break; + } + + case TYPE_RECORD: { + if ( ! v ) + return (optional && ! calc_static_size); + + const RecordVal* rv = v->AsRecordVal(); + RecordType* rt = bt->AsRecordType(); + int num_fields = rt->NumFields(); + + for ( int i = 0; i < num_fields; ++i ) { + Attributes* a = rt->FieldDecl(i)->attrs.get(); + bool optional_attr = (a && a->Find(ATTR_OPTIONAL)); + + auto rv_v = rv ? rv->GetField(i) : nullptr; + if ( ! ReserveSingleTypeKeySize(hk, rt->GetFieldType(i).get(), rv_v.get(), type_check, + optional_attr, calc_static_size, false) ) + return false; + } + break; + } + + case TYPE_TABLE: { + if ( ! v ) + return (optional && ! calc_static_size); + + auto tv = v->AsTableVal(); + auto hashkeys = ordered_hashkeys(tv); + + hk.ReserveType("table-size"); + + for ( auto& kv : *hashkeys ) { + auto key = kv.second; + + if ( ! ReserveSingleTypeKeySize(hk, key->GetType().get(), key.get(), type_check, false, + calc_static_size, false) ) + return false; + + if ( ! bt->IsSet() ) { + auto val = const_cast(tv)->FindOrDefault(key); + if ( ! ReserveSingleTypeKeySize(hk, val->GetType().get(), val.get(), type_check, false, + calc_static_size, false) ) + return false; + } + } + + break; + } + + case TYPE_VECTOR: { + if ( ! v ) + return (optional && ! calc_static_size); + + hk.ReserveType("vector-size"); + VectorVal* vv = const_cast(v->AsVectorVal()); + for ( unsigned int i = 0; i < vv->Size(); ++i ) { + auto val = vv->ValAt(i); + hk.ReserveType("vector-idx"); + hk.ReserveType("vector-idx-present"); + if ( val && ! ReserveSingleTypeKeySize(hk, bt->AsVectorType()->Yield().get(), val.get(), + type_check, false, calc_static_size, false) ) + return false; + } + break; + } + + case TYPE_LIST: { + if ( ! v ) + return (optional && ! calc_static_size); + + hk.ReserveType("list-size"); + ListVal* lv = const_cast(v->AsListVal()); + for ( int i = 0; i < lv->Length(); ++i ) { + if ( ! ReserveSingleTypeKeySize(hk, lv->Idx(i)->GetType().get(), lv->Idx(i).get(), type_check, + false, calc_static_size, false) ) + return false; + } + + break; + } + + default: { + reporter->InternalError("bad index type in CompositeHash::ReserveSingleTypeKeySize"); + return false; + } + } + + break; // case TYPE_INTERNAL_VOID/OTHER + } + + case TYPE_INTERNAL_STRING: + if ( ! v ) + return (optional && ! calc_static_size); + if ( ! singleton ) + hk.ReserveType("string-len"); + hk.Reserve("string", v->AsString()->Len()); + break; + + case TYPE_INTERNAL_ERROR: return false; + } + + return true; +} + +} // namespace zeek::detail diff --git a/src/CompHash.h b/src/CompHash.h index eb76090bd3..a8cde76595 100644 --- a/src/CompHash.h +++ b/src/CompHash.h @@ -7,67 +7,61 @@ #include "zeek/Func.h" #include "zeek/Type.h" -namespace zeek - { +namespace zeek { class ListVal; using ListValPtr = zeek::IntrusivePtr; - } // namespace zeek +} // namespace zeek -namespace zeek::detail - { +namespace zeek::detail { class HashKey; -class CompositeHash - { +class CompositeHash { public: - explicit CompositeHash(TypeListPtr composite_type); + explicit CompositeHash(TypeListPtr composite_type); - // Compute the hash corresponding to the given index val, - // or nullptr if it fails to typecheck. - std::unique_ptr MakeHashKey(const Val& v, bool type_check) const; + // Compute the hash corresponding to the given index val, + // or nullptr if it fails to typecheck. + std::unique_ptr MakeHashKey(const Val& v, bool type_check) const; - // Given a hash key, recover the values used to create it. - ListValPtr RecoverVals(const HashKey& k) const; + // Given a hash key, recover the values used to create it. + ListValPtr RecoverVals(const HashKey& k) const; protected: - bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, - bool singleton) const; + bool SingleValHash(HashKey& hk, const Val* v, Type* bt, bool type_check, bool optional, bool singleton) const; - // Recovers just one Val of possibly many; called from RecoverVals. - // Upon return, pval will point to the recovered Val of type t. - // Returns and updated kp for the next Val. Calls reporter->InternalError() - // upon errors, so there is no return value for invalid input. - bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, - bool singleton) const; + // Recovers just one Val of possibly many; called from RecoverVals. + // Upon return, pval will point to the recovered Val of type t. + // Returns and updated kp for the next Val. Calls reporter->InternalError() + // upon errors, so there is no return value for invalid input. + bool RecoverOneVal(const HashKey& k, Type* t, ValPtr* pval, bool optional, bool singleton) const; - // Compute the size of the composite key. If v is non-nil then - // the value is computed for the particular list of values. - // Returns 0 if the key has an indeterminate size (if v not given), - // or if v doesn't match the index type (if given). - bool ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool calc_static_size) const; + // Compute the size of the composite key. If v is non-nil then + // the value is computed for the particular list of values. + // Returns 0 if the key has an indeterminate size (if v not given), + // or if v doesn't match the index type (if given). + bool ReserveKeySize(HashKey& hk, const Val* v, bool type_check, bool calc_static_size) const; - bool ReserveSingleTypeKeySize(HashKey& hk, Type*, const Val* v, bool type_check, bool optional, - bool calc_static_size, bool singleton) const; + bool ReserveSingleTypeKeySize(HashKey& hk, Type*, const Val* v, bool type_check, bool optional, + bool calc_static_size, bool singleton) const; - bool EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const; + bool EnsureTypeReserve(HashKey& hk, const Val* v, Type* bt, bool type_check) const; - // The following are for allowing hashing of function values. - // These can occur, for example, in sets of predicates that get - // iterated over. We use pointers in order to keep storage - // lower for the common case of these not being needed. - std::unique_ptr> func_to_func_id; - std::unique_ptr> func_id_to_func; - void BuildFuncMappings() - { - func_to_func_id = std::make_unique>(); - func_id_to_func = std::make_unique>(); - } + // The following are for allowing hashing of function values. + // These can occur, for example, in sets of predicates that get + // iterated over. We use pointers in order to keep storage + // lower for the common case of these not being needed. + std::unique_ptr> func_to_func_id; + std::unique_ptr> func_id_to_func; + void BuildFuncMappings() { + func_to_func_id = std::make_unique>(); + func_id_to_func = std::make_unique>(); + } - TypeListPtr type; - bool is_singleton = false; // if just one type in index - }; + TypeListPtr type; + bool is_singleton = false; // if just one type in index +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Conn.cc b/src/Conn.cc index 5d39db982a..1f21df7775 100644 --- a/src/Conn.cc +++ b/src/Conn.cc @@ -22,470 +22,410 @@ #include "zeek/packet_analysis/protocol/tcp/TCP.h" #include "zeek/session/Manager.h" -namespace zeek - { +namespace zeek { uint64_t Connection::total_connections = 0; uint64_t Connection::current_connections = 0; -Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, - const Packet* pkt) - : Session(t, connection_timeout, connection_status_update, - detail::connection_status_update_interval), - key(k) - { - orig_addr = id->src_addr; - resp_addr = id->dst_addr; - orig_port = id->src_port; - resp_port = id->dst_port; - proto = TRANSPORT_UNKNOWN; - orig_flow_label = flow; - resp_flow_label = 0; - saw_first_orig_packet = 1; - saw_first_resp_packet = 0; +Connection::Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt) + : Session(t, connection_timeout, connection_status_update, detail::connection_status_update_interval), key(k) { + orig_addr = id->src_addr; + resp_addr = id->dst_addr; + orig_port = id->src_port; + resp_port = id->dst_port; + proto = TRANSPORT_UNKNOWN; + orig_flow_label = flow; + resp_flow_label = 0; + saw_first_orig_packet = 1; + saw_first_resp_packet = 0; - if ( pkt->l2_src ) - memcpy(orig_l2_addr, pkt->l2_src, sizeof(orig_l2_addr)); - else - memset(orig_l2_addr, 0, sizeof(orig_l2_addr)); + if ( pkt->l2_src ) + memcpy(orig_l2_addr, pkt->l2_src, sizeof(orig_l2_addr)); + else + memset(orig_l2_addr, 0, sizeof(orig_l2_addr)); - if ( pkt->l2_dst ) - memcpy(resp_l2_addr, pkt->l2_dst, sizeof(resp_l2_addr)); - else - memset(resp_l2_addr, 0, sizeof(resp_l2_addr)); + if ( pkt->l2_dst ) + memcpy(resp_l2_addr, pkt->l2_dst, sizeof(resp_l2_addr)); + else + memset(resp_l2_addr, 0, sizeof(resp_l2_addr)); - vlan = pkt->vlan; - inner_vlan = pkt->inner_vlan; + vlan = pkt->vlan; + inner_vlan = pkt->inner_vlan; - weird = 0; + weird = 0; - suppress_event = 0; + suppress_event = 0; - finished = 0; + finished = 0; - hist_seen = 0; - history = ""; + hist_seen = 0; + history = ""; - adapter = nullptr; - primary_PIA = nullptr; + adapter = nullptr; + primary_PIA = nullptr; - ++current_connections; - ++total_connections; + ++current_connections; + ++total_connections; - encapsulation = pkt->encap; - } + encapsulation = pkt->encap; +} -Connection::~Connection() - { - if ( ! finished ) - reporter->InternalError("Done() not called before destruction of Connection"); +Connection::~Connection() { + if ( ! finished ) + reporter->InternalError("Done() not called before destruction of Connection"); - CancelTimers(); + CancelTimers(); - if ( conn_val ) - conn_val->SetOrigin(nullptr); + if ( conn_val ) + conn_val->SetOrigin(nullptr); - delete adapter; + delete adapter; - --current_connections; - } + --current_connections; +} -void Connection::CheckEncapsulation(const std::shared_ptr& arg_encap) - { - if ( encapsulation && arg_encap ) - { - if ( *encapsulation != *arg_encap ) - { - if ( tunnel_changed && - (zeek::detail::tunnel_max_changes_per_connection == 0 || - tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) ) - { - tunnel_changes++; - EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); - } +void Connection::CheckEncapsulation(const std::shared_ptr& arg_encap) { + if ( encapsulation && arg_encap ) { + if ( *encapsulation != *arg_encap ) { + if ( tunnel_changed && (zeek::detail::tunnel_max_changes_per_connection == 0 || + tunnel_changes < zeek::detail::tunnel_max_changes_per_connection) ) { + tunnel_changes++; + EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); + } - encapsulation = std::make_shared(*arg_encap); - } - } + encapsulation = std::make_shared(*arg_encap); + } + } - else if ( encapsulation ) - { - if ( tunnel_changed ) - { - EncapsulationStack empty; - EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal()); - } + else if ( encapsulation ) { + if ( tunnel_changed ) { + EncapsulationStack empty; + EnqueueEvent(tunnel_changed, nullptr, GetVal(), empty.ToVal()); + } - encapsulation = nullptr; - } + encapsulation = nullptr; + } - else if ( arg_encap ) - { - if ( tunnel_changed ) - EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); + else if ( arg_encap ) { + if ( tunnel_changed ) + EnqueueEvent(tunnel_changed, nullptr, GetVal(), arg_encap->ToVal()); - encapsulation = std::make_shared(*arg_encap); - } - } + encapsulation = std::make_shared(*arg_encap); + } +} -void Connection::Done() - { - finished = 1; +void Connection::Done() { + finished = 1; - if ( adapter ) - { - if ( ConnTransport() == TRANSPORT_TCP ) - { - auto* ta = static_cast(adapter); - assert(ta->IsAnalyzer("TCP")); - analyzer::tcp::TCP_Endpoint* to = ta->Orig(); - analyzer::tcp::TCP_Endpoint* tr = ta->Resp(); + if ( adapter ) { + if ( ConnTransport() == TRANSPORT_TCP ) { + auto* ta = static_cast(adapter); + assert(ta->IsAnalyzer("TCP")); + analyzer::tcp::TCP_Endpoint* to = ta->Orig(); + analyzer::tcp::TCP_Endpoint* tr = ta->Resp(); - packet_analysis::TCP::TCPAnalyzer::GetStats().StateLeft(to->state, tr->state); - } + packet_analysis::TCP::TCPAnalyzer::GetStats().StateLeft(to->state, tr->state); + } - if ( ! adapter->IsFinished() ) - adapter->Done(); - } - } + if ( ! adapter->IsFinished() ) + adapter->Done(); + } +} -void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, - const u_char*& data, int& record_packet, int& record_content, +void Connection::NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data, + int& record_packet, int& record_content, // arguments for reproducing packets - const Packet* pkt) - { - run_state::current_timestamp = t; - run_state::current_pkt = pkt; + const Packet* pkt) { + run_state::current_timestamp = t; + run_state::current_pkt = pkt; - if ( adapter ) - { - if ( adapter->Skipping() ) - return; + if ( adapter ) { + if ( adapter->Skipping() ) + return; - record_current_packet = record_packet; - record_current_content = record_content; - adapter->NextPacket(len, data, is_orig, -1, ip, caplen); - record_packet = record_current_packet; - record_content = record_current_content; - } - else - last_time = t; + record_current_packet = record_packet; + record_current_content = record_content; + adapter->NextPacket(len, data, is_orig, -1, ip, caplen); + record_packet = record_current_packet; + record_content = record_current_content; + } + else + last_time = t; - run_state::current_timestamp = 0; - run_state::current_pkt = nullptr; - } + run_state::current_timestamp = 0; + run_state::current_pkt = nullptr; +} -bool Connection::IsReuse(double t, const u_char* pkt) - { - return adapter && adapter->IsReuse(t, pkt); - } +bool Connection::IsReuse(double t, const u_char* pkt) { return adapter && adapter->IsReuse(t, pkt); } -bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, - uint32_t scaling_base) - { - if ( ++counter == scaling_threshold ) - { - AddHistory(code); +bool Connection::ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base) { + if ( ++counter == scaling_threshold ) { + AddHistory(code); - auto new_threshold = scaling_threshold * scaling_base; + auto new_threshold = scaling_threshold * scaling_base; - if ( new_threshold <= scaling_threshold ) - // This can happen due to wrap-around. In that - // case, reset the counter but leave the threshold - // unchanged. - counter = 0; + if ( new_threshold <= scaling_threshold ) + // This can happen due to wrap-around. In that + // case, reset the counter but leave the threshold + // unchanged. + counter = 0; - else - scaling_threshold = new_threshold; + else + scaling_threshold = new_threshold; - return true; - } + return true; + } - return false; - } + return false; +} -void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) - { - if ( ! e ) - return; +void Connection::HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold) { + if ( ! e ) + return; - if ( threshold == 1 ) - // This will be far and away the most common case, - // and at this stage it's not a *multiple* instance. - return; + if ( threshold == 1 ) + // This will be far and away the most common case, + // and at this stage it's not a *multiple* instance. + return; - EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold)); - } + EnqueueEvent(e, nullptr, GetVal(), val_mgr->Bool(is_orig), val_mgr->Count(threshold)); +} -namespace - { +namespace { // Flip everything that needs to be flipped in the connection // record that is known on this level. This needs to align // with GetVal() and connection's layout in init-bare. -void flip_conn_val(const RecordValPtr& conn_val) - { - // Flip the the conn_id (c$id). - const auto& id_val = conn_val->GetField(0); - const auto& tmp_addr = id_val->GetField(0); - const auto& tmp_port = id_val->GetField(1); - id_val->Assign(0, id_val->GetField(2)); - id_val->Assign(1, id_val->GetField(3)); - id_val->Assign(2, tmp_addr); - id_val->Assign(3, tmp_port); +void flip_conn_val(const RecordValPtr& conn_val) { + // Flip the the conn_id (c$id). + const auto& id_val = conn_val->GetField(0); + const auto& tmp_addr = id_val->GetField(0); + const auto& tmp_port = id_val->GetField(1); + id_val->Assign(0, id_val->GetField(2)); + id_val->Assign(1, id_val->GetField(3)); + id_val->Assign(2, tmp_addr); + id_val->Assign(3, tmp_port); - // Flip the endpoints within connection. - const auto& tmp_endp = conn_val->GetField(1); - conn_val->Assign(1, conn_val->GetField(2)); - conn_val->Assign(2, tmp_endp); - } - } + // Flip the endpoints within connection. + const auto& tmp_endp = conn_val->GetField(1); + conn_val->Assign(1, conn_val->GetField(2)); + conn_val->Assign(2, tmp_endp); +} +} // namespace -const RecordValPtr& Connection::GetVal() - { - if ( ! conn_val ) - { - conn_val = make_intrusive(id::connection); +const RecordValPtr& Connection::GetVal() { + if ( ! conn_val ) { + conn_val = make_intrusive(id::connection); - TransportProto prot_type = ConnTransport(); + TransportProto prot_type = ConnTransport(); - auto id_val = make_intrusive(id::conn_id); - id_val->Assign(0, make_intrusive(orig_addr)); - id_val->Assign(1, val_mgr->Port(ntohs(orig_port), prot_type)); - id_val->Assign(2, make_intrusive(resp_addr)); - id_val->Assign(3, val_mgr->Port(ntohs(resp_port), prot_type)); + auto id_val = make_intrusive(id::conn_id); + id_val->Assign(0, make_intrusive(orig_addr)); + id_val->Assign(1, val_mgr->Port(ntohs(orig_port), prot_type)); + id_val->Assign(2, make_intrusive(resp_addr)); + id_val->Assign(3, val_mgr->Port(ntohs(resp_port), prot_type)); - auto orig_endp = make_intrusive(id::endpoint); - orig_endp->Assign(0, 0); - orig_endp->Assign(1, 0); - orig_endp->Assign(4, orig_flow_label); + auto orig_endp = make_intrusive(id::endpoint); + orig_endp->Assign(0, 0); + orig_endp->Assign(1, 0); + orig_endp->Assign(4, orig_flow_label); - const int l2_len = sizeof(orig_l2_addr); - char null[l2_len]{}; + const int l2_len = sizeof(orig_l2_addr); + char null[l2_len]{}; - if ( memcmp(&orig_l2_addr, &null, l2_len) != 0 ) - orig_endp->Assign(5, fmt_mac(orig_l2_addr, l2_len)); + if ( memcmp(&orig_l2_addr, &null, l2_len) != 0 ) + orig_endp->Assign(5, fmt_mac(orig_l2_addr, l2_len)); - auto resp_endp = make_intrusive(id::endpoint); - resp_endp->Assign(0, 0); - resp_endp->Assign(1, 0); - resp_endp->Assign(4, resp_flow_label); + auto resp_endp = make_intrusive(id::endpoint); + resp_endp->Assign(0, 0); + resp_endp->Assign(1, 0); + resp_endp->Assign(4, resp_flow_label); - if ( memcmp(&resp_l2_addr, &null, l2_len) != 0 ) - resp_endp->Assign(5, fmt_mac(resp_l2_addr, l2_len)); + if ( memcmp(&resp_l2_addr, &null, l2_len) != 0 ) + resp_endp->Assign(5, fmt_mac(resp_l2_addr, l2_len)); - conn_val->Assign(0, std::move(id_val)); - conn_val->Assign(1, std::move(orig_endp)); - conn_val->Assign(2, std::move(resp_endp)); - // 3 and 4 are set below. - conn_val->Assign(5, make_intrusive(id::string_set)); // service - conn_val->Assign(6, val_mgr->EmptyString()); // history + conn_val->Assign(0, std::move(id_val)); + conn_val->Assign(1, std::move(orig_endp)); + conn_val->Assign(2, std::move(resp_endp)); + // 3 and 4 are set below. + conn_val->Assign(5, make_intrusive(id::string_set)); // service + conn_val->Assign(6, val_mgr->EmptyString()); // history - if ( ! uid ) - uid.Set(zeek::detail::bits_per_uid); + if ( ! uid ) + uid.Set(zeek::detail::bits_per_uid); - conn_val->Assign(7, uid.Base62("C")); + conn_val->Assign(7, uid.Base62("C")); - if ( encapsulation && encapsulation->Depth() > 0 ) - conn_val->Assign(8, encapsulation->ToVal()); + if ( encapsulation && encapsulation->Depth() > 0 ) + conn_val->Assign(8, encapsulation->ToVal()); - if ( vlan != 0 ) - conn_val->Assign(9, vlan); + if ( vlan != 0 ) + conn_val->Assign(9, vlan); - if ( inner_vlan != 0 ) - conn_val->Assign(10, inner_vlan); - } + if ( inner_vlan != 0 ) + conn_val->Assign(10, inner_vlan); + } - if ( adapter ) - adapter->UpdateConnVal(conn_val.get()); + if ( adapter ) + adapter->UpdateConnVal(conn_val.get()); - conn_val->AssignTime(3, start_time); // ### - conn_val->AssignInterval(4, last_time - start_time); + conn_val->AssignTime(3, start_time); // ### + conn_val->AssignInterval(4, last_time - start_time); - if ( ! history.empty() ) - { - auto v = conn_val->GetFieldAs(6); - if ( *v != history ) - conn_val->Assign(6, history); - } + if ( ! history.empty() ) { + auto v = conn_val->GetFieldAs(6); + if ( *v != history ) + conn_val->Assign(6, history); + } - conn_val->SetOrigin(this); + conn_val->SetOrigin(this); - return conn_val; - } + return conn_val; +} -analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) - { - return adapter ? adapter->FindChild(id) : nullptr; - } +analyzer::Analyzer* Connection::FindAnalyzer(analyzer::ID id) { return adapter ? adapter->FindChild(id) : nullptr; } -analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) - { - return adapter ? adapter->FindChild(tag) : nullptr; - } +analyzer::Analyzer* Connection::FindAnalyzer(const zeek::Tag& tag) { + return adapter ? adapter->FindChild(tag) : nullptr; +} -analyzer::Analyzer* Connection::FindAnalyzer(const char* name) - { - return adapter->FindChild(name); - } +analyzer::Analyzer* Connection::FindAnalyzer(const char* name) { return adapter->FindChild(name); } -void Connection::AppendAddl(const char* str) - { - const auto& cv = GetVal(); +void Connection::AppendAddl(const char* str) { + const auto& cv = GetVal(); - const char* old = cv->GetFieldAs(6)->CheckString(); - const char* format = *old ? "%s %s" : "%s%s"; + const char* old = cv->GetFieldAs(6)->CheckString(); + const char* format = *old ? "%s %s" : "%s%s"; - cv->Assign(6, util::fmt(format, old, str)); - } + cv->Assign(6, util::fmt(format, old, str)); +} -void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, - bool bol, bool eol, bool clear_state) - { - if ( primary_PIA ) - primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state); - } +void Connection::Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol, + bool clear_state) { + if ( primary_PIA ) + primary_PIA->Match(type, data, len, is_orig, bol, eol, clear_state); +} -void Connection::RemovalEvent() - { - if ( connection_state_remove ) - EnqueueEvent(connection_state_remove, nullptr, GetVal()); - } +void Connection::RemovalEvent() { + if ( connection_state_remove ) + EnqueueEvent(connection_state_remove, nullptr, GetVal()); +} -void Connection::Weird(const char* name, const char* addl, const char* source) - { - weird = 1; - reporter->Weird(this, name, addl ? addl : "", source ? source : ""); - } +void Connection::Weird(const char* name, const char* addl, const char* source) { + weird = 1; + reporter->Weird(this, name, addl ? addl : "", source ? source : ""); +} -void Connection::FlipRoles() - { - IPAddr tmp_addr = resp_addr; - resp_addr = orig_addr; - orig_addr = tmp_addr; +void Connection::FlipRoles() { + IPAddr tmp_addr = resp_addr; + resp_addr = orig_addr; + orig_addr = tmp_addr; - uint32_t tmp_port = resp_port; - resp_port = orig_port; - orig_port = tmp_port; + uint32_t tmp_port = resp_port; + resp_port = orig_port; + orig_port = tmp_port; - const int l2_len = sizeof(orig_l2_addr); - u_char tmp_l2_addr[l2_len]; - memcpy(tmp_l2_addr, resp_l2_addr, l2_len); - memcpy(resp_l2_addr, orig_l2_addr, l2_len); - memcpy(orig_l2_addr, tmp_l2_addr, l2_len); + const int l2_len = sizeof(orig_l2_addr); + u_char tmp_l2_addr[l2_len]; + memcpy(tmp_l2_addr, resp_l2_addr, l2_len); + memcpy(resp_l2_addr, orig_l2_addr, l2_len); + memcpy(orig_l2_addr, tmp_l2_addr, l2_len); - bool tmp_bool = saw_first_resp_packet; - saw_first_resp_packet = saw_first_orig_packet; - saw_first_orig_packet = tmp_bool; + bool tmp_bool = saw_first_resp_packet; + saw_first_resp_packet = saw_first_orig_packet; + saw_first_orig_packet = tmp_bool; - uint32_t tmp_flow = resp_flow_label; - resp_flow_label = orig_flow_label; - orig_flow_label = tmp_flow; + uint32_t tmp_flow = resp_flow_label; + resp_flow_label = orig_flow_label; + orig_flow_label = tmp_flow; - if ( conn_val ) - flip_conn_val(conn_val); + if ( conn_val ) + flip_conn_val(conn_val); - if ( adapter ) - adapter->FlipRoles(); + if ( adapter ) + adapter->FlipRoles(); - analyzer_mgr->ApplyScheduledAnalyzers(this); + analyzer_mgr->ApplyScheduledAnalyzers(this); - AddHistory('^'); + AddHistory('^'); - if ( connection_flipped ) - EnqueueEvent(connection_flipped, nullptr, GetVal()); - } + if ( connection_flipped ) + EnqueueEvent(connection_flipped, nullptr, GetVal()); +} -void Connection::Describe(ODesc* d) const - { - session::Session::Describe(d); +void Connection::Describe(ODesc* d) const { + session::Session::Describe(d); - switch ( proto ) - { - case TRANSPORT_TCP: - d->Add("TCP"); - break; + switch ( proto ) { + case TRANSPORT_TCP: d->Add("TCP"); break; - case TRANSPORT_UDP: - d->Add("UDP"); - break; + case TRANSPORT_UDP: d->Add("UDP"); break; - case TRANSPORT_ICMP: - d->Add("ICMP"); - break; + case TRANSPORT_ICMP: d->Add("ICMP"); break; - case TRANSPORT_UNKNOWN: - d->Add("unknown"); - reporter->InternalWarning("unknown transport in Connection::Describe()"); + case TRANSPORT_UNKNOWN: + d->Add("unknown"); + reporter->InternalWarning("unknown transport in Connection::Describe()"); - break; + break; - default: - reporter->InternalError("unhandled transport type in Connection::Describe"); - } + default: reporter->InternalError("unhandled transport type in Connection::Describe"); + } - d->SP(); - d->Add(orig_addr); - d->Add(":"); - d->Add(ntohs(orig_port)); + d->SP(); + d->Add(orig_addr); + d->Add(":"); + d->Add(ntohs(orig_port)); - d->SP(); - d->AddSP("->"); + d->SP(); + d->AddSP("->"); - d->Add(resp_addr); - d->Add(":"); - d->Add(ntohs(resp_port)); + d->Add(resp_addr); + d->Add(":"); + d->Add(ntohs(resp_port)); - d->NL(); - } + d->NL(); +} -void Connection::IDString(ODesc* d) const - { - d->Add(orig_addr); - d->AddRaw(":", 1); - d->Add(ntohs(orig_port)); - d->AddRaw(" > ", 3); - d->Add(resp_addr); - d->AddRaw(":", 1); - d->Add(ntohs(resp_port)); - } +void Connection::IDString(ODesc* d) const { + d->Add(orig_addr); + d->AddRaw(":", 1); + d->Add(ntohs(orig_port)); + d->AddRaw(" > ", 3); + d->Add(resp_addr); + d->AddRaw(":", 1); + d->Add(ntohs(resp_port)); +} -void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) - { - adapter = aa; - primary_PIA = pia; - } +void Connection::SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia) { + adapter = aa; + primary_PIA = pia; +} -void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) - { - uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label; +void Connection::CheckFlowLabel(bool is_orig, uint32_t flow_label) { + uint32_t& my_flow_label = is_orig ? orig_flow_label : resp_flow_label; - if ( my_flow_label != flow_label ) - { - if ( conn_val ) - { - RecordVal* endp = conn_val->GetFieldAs(is_orig ? 1 : 2); - endp->Assign(4, flow_label); - } + if ( my_flow_label != flow_label ) { + if ( conn_val ) { + RecordVal* endp = conn_val->GetFieldAs(is_orig ? 1 : 2); + endp->Assign(4, flow_label); + } - if ( connection_flow_label_changed && - (is_orig ? saw_first_orig_packet : saw_first_resp_packet) ) - { - EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig), - val_mgr->Count(my_flow_label), val_mgr->Count(flow_label)); - } + if ( connection_flow_label_changed && (is_orig ? saw_first_orig_packet : saw_first_resp_packet) ) { + EnqueueEvent(connection_flow_label_changed, nullptr, GetVal(), val_mgr->Bool(is_orig), + val_mgr->Count(my_flow_label), val_mgr->Count(flow_label)); + } - my_flow_label = flow_label; - } + my_flow_label = flow_label; + } - if ( is_orig ) - saw_first_orig_packet = 1; - else - saw_first_resp_packet = 1; - } + if ( is_orig ) + saw_first_orig_packet = 1; + else + saw_first_resp_packet = 1; +} -bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) - { - return detail::PermitWeird(weird_state, name, threshold, rate, duration); - } +bool Connection::PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration) { + return detail::PermitWeird(weird_state, name, threshold, rate, duration); +} - } // namespace zeek +} // namespace zeek diff --git a/src/Conn.h b/src/Conn.h index ded0865a6d..6ad9d56a55 100644 --- a/src/Conn.h +++ b/src/Conn.h @@ -19,8 +19,7 @@ #include "zeek/iosource/Packet.h" #include "zeek/session/Session.h" -namespace zeek - { +namespace zeek { class Connection; class EncapsulationStack; @@ -30,253 +29,235 @@ class RecordVal; using ValPtr = IntrusivePtr; using RecordValPtr = IntrusivePtr; -namespace session - { +namespace session { class Manager; - } -namespace detail - { +} +namespace detail { class Specific_RE_Matcher; class RuleEndpointState; class RuleHdrTest; - } // namespace detail +} // namespace detail -namespace analyzer - { +namespace analyzer { class Analyzer; - } -namespace packet_analysis::IP - { +} +namespace packet_analysis::IP { class SessionAdapter; - } +} -enum ConnEventToFlag - { - NUL_IN_LINE, - SINGULAR_CR, - SINGULAR_LF, - NUM_EVENTS_TO_FLAG, - }; +enum ConnEventToFlag { + NUL_IN_LINE, + SINGULAR_CR, + SINGULAR_LF, + NUM_EVENTS_TO_FLAG, +}; -struct ConnTuple - { - IPAddr src_addr; - IPAddr dst_addr; - uint32_t src_port = 0; - uint32_t dst_port = 0; - bool is_one_way = false; // if true, don't canonicalize order - TransportProto proto = TRANSPORT_UNKNOWN; - }; +struct ConnTuple { + IPAddr src_addr; + IPAddr dst_addr; + uint32_t src_port = 0; + uint32_t dst_port = 0; + bool is_one_way = false; // if true, don't canonicalize order + TransportProto proto = TRANSPORT_UNKNOWN; +}; -static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, - uint32_t p2) - { - return addr1 < addr2 || (addr1 == addr2 && p1 < p2); - } +static inline int addr_port_canon_lt(const IPAddr& addr1, uint32_t p1, const IPAddr& addr2, uint32_t p2) { + return addr1 < addr2 || (addr1 == addr2 && p1 < p2); +} -class Connection final : public session::Session - { +class Connection final : public session::Session { public: - Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, - const Packet* pkt); - ~Connection() override; + Connection(const detail::ConnKey& k, double t, const ConnTuple* id, uint32_t flow, const Packet* pkt); + ~Connection() override; - /** - * Invoked when an encapsulation is discovered. It records the encapsulation - * with the connection and raises a "tunnel_changed" event if it's different - * from the previous encapsulation or if it's the first one encountered. - * - * @param encap The new encapsulation. Can be set to null to indicated no - * encapsulation or clear an old one. - */ - void CheckEncapsulation(const std::shared_ptr& encap); + /** + * Invoked when an encapsulation is discovered. It records the encapsulation + * with the connection and raises a "tunnel_changed" event if it's different + * from the previous encapsulation or if it's the first one encountered. + * + * @param encap The new encapsulation. Can be set to null to indicated no + * encapsulation or clear an old one. + */ + void CheckEncapsulation(const std::shared_ptr& encap); - /** - * Invoked when the session is about to be removed. Use Ref(this) - * inside Done to keep the session object around, though it'll - * no longer be accessible from the SessionManager. - */ - void Done() override; + /** + * Invoked when the session is about to be removed. Use Ref(this) + * inside Done to keep the session object around, though it'll + * no longer be accessible from the SessionManager. + */ + void Done() override; - // Process the connection's next packet. "data" points just - // beyond the IP header. It's updated to point just beyond - // the transport header (or whatever should be saved, if we - // decide not to save the full packet contents). - // - // If record_packet is true, the packet should be recorded. - // If record_content is true, then its entire contents should - // be recorded, otherwise just up through the transport header. - // Both are assumed set to true when called. - void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, - const u_char*& data, int& record_packet, int& record_content, - // arguments for reproducing packets - const Packet* pkt); + // Process the connection's next packet. "data" points just + // beyond the IP header. It's updated to point just beyond + // the transport header (or whatever should be saved, if we + // decide not to save the full packet contents). + // + // If record_packet is true, the packet should be recorded. + // If record_content is true, then its entire contents should + // be recorded, otherwise just up through the transport header. + // Both are assumed set to true when called. + void NextPacket(double t, bool is_orig, const IP_Hdr* ip, int len, int caplen, const u_char*& data, + int& record_packet, int& record_content, + // arguments for reproducing packets + const Packet* pkt); - // Keys are only considered valid for a connection when a - // connection is in the session map. If it is removed, the key - // should be marked invalid. - const detail::ConnKey& Key() const { return key; } - session::detail::Key SessionKey(bool copy) const override - { - return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE, - copy}; - } + // Keys are only considered valid for a connection when a + // connection is in the session map. If it is removed, the key + // should be marked invalid. + const detail::ConnKey& Key() const { return key; } + session::detail::Key SessionKey(bool copy) const override { + return session::detail::Key{&key, sizeof(key), session::detail::Key::CONNECTION_KEY_TYPE, copy}; + } - const IPAddr& OrigAddr() const { return orig_addr; } - const IPAddr& RespAddr() const { return resp_addr; } + const IPAddr& OrigAddr() const { return orig_addr; } + const IPAddr& RespAddr() const { return resp_addr; } - uint32_t OrigPort() const { return orig_port; } - uint32_t RespPort() const { return resp_port; } + uint32_t OrigPort() const { return orig_port; } + uint32_t RespPort() const { return resp_port; } - void FlipRoles(); + void FlipRoles(); - analyzer::Analyzer* FindAnalyzer(analyzer::ID id); - analyzer::Analyzer* FindAnalyzer(const zeek::Tag& tag); // find first in tree. - analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree. + analyzer::Analyzer* FindAnalyzer(analyzer::ID id); + analyzer::Analyzer* FindAnalyzer(const zeek::Tag& tag); // find first in tree. + analyzer::Analyzer* FindAnalyzer(const char* name); // find first in tree. - TransportProto ConnTransport() const { return proto; } - std::string TransportIdentifier() const override - { - if ( proto == TRANSPORT_TCP ) - return "tcp"; - else if ( proto == TRANSPORT_UDP ) - return "udp"; - else if ( proto == TRANSPORT_ICMP ) - return "icmp"; - else - return "unknown"; - } + TransportProto ConnTransport() const { return proto; } + std::string TransportIdentifier() const override { + if ( proto == TRANSPORT_TCP ) + return "tcp"; + else if ( proto == TRANSPORT_UDP ) + return "udp"; + else if ( proto == TRANSPORT_ICMP ) + return "icmp"; + else + return "unknown"; + } - // Returns true if the packet reflects a reuse of this - // connection (i.e., not a continuation but the beginning of - // a new connection). - bool IsReuse(double t, const u_char* pkt); + // Returns true if the packet reflects a reuse of this + // connection (i.e., not a continuation but the beginning of + // a new connection). + bool IsReuse(double t, const u_char* pkt); - /** - * Returns the associated "connection" record. - */ - const RecordValPtr& GetVal() override; + /** + * Returns the associated "connection" record. + */ + const RecordValPtr& GetVal() override; - /** - * Append additional entries to the history field in the connection record. - */ - void AppendAddl(const char* str); + /** + * Append additional entries to the history field in the connection record. + */ + void AppendAddl(const char* str); - void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, - bool eol, bool clear_state); + void Match(detail::Rule::PatternType type, const u_char* data, int len, bool is_orig, bool bol, bool eol, + bool clear_state); - /** - * Generates connection removal event(s). - */ - void RemovalEvent() override; + /** + * Generates connection removal event(s). + */ + void RemovalEvent() override; - void Weird(const char* name, const char* addl = "", const char* source = ""); - bool DidWeird() const { return weird != 0; } + void Weird(const char* name, const char* addl = "", const char* source = ""); + bool DidWeird() const { return weird != 0; } - inline bool FlagEvent(ConnEventToFlag e) - { - if ( e >= 0 && e < NUM_EVENTS_TO_FLAG ) - { - if ( suppress_event & (1 << e) ) - return false; - suppress_event |= 1 << e; - } + inline bool FlagEvent(ConnEventToFlag e) { + if ( e >= 0 && e < NUM_EVENTS_TO_FLAG ) { + if ( suppress_event & (1 << e) ) + return false; + suppress_event |= 1 << e; + } - return true; - } + return true; + } - void Describe(ODesc* d) const override; - void IDString(ODesc* d) const; + void Describe(ODesc* d) const override; + void IDString(ODesc* d) const; - // Statistics. + // Statistics. - static uint64_t TotalConnections() { return total_connections; } - static uint64_t CurrentConnections() { return current_connections; } + static uint64_t TotalConnections() { return total_connections; } + static uint64_t CurrentConnections() { return current_connections; } - // Returns true if the history was already seen, false otherwise. - bool CheckHistory(uint32_t mask, char code) - { - if ( (hist_seen & mask) == 0 ) - { - hist_seen |= mask; - AddHistory(code); - return false; - } - else - return true; - } + // Returns true if the history was already seen, false otherwise. + bool CheckHistory(uint32_t mask, char code) { + if ( (hist_seen & mask) == 0 ) { + hist_seen |= mask; + AddHistory(code); + return false; + } + else + return true; + } - // Increments the passed counter and adds it as a history - // code if it has crossed the next scaling threshold. Scaling - // is done in terms of powers of the third argument. - // Returns true if the threshold was crossed, false otherwise. - bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, - uint32_t scaling_base = 10); + // Increments the passed counter and adds it as a history + // code if it has crossed the next scaling threshold. Scaling + // is done in terms of powers of the third argument. + // Returns true if the threshold was crossed, false otherwise. + bool ScaledHistoryEntry(char code, uint32_t& counter, uint32_t& scaling_threshold, uint32_t scaling_base = 10); - void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold); + void HistoryThresholdEvent(EventHandlerPtr e, bool is_orig, uint32_t threshold); - void AddHistory(char code) { history += code; } + void AddHistory(char code) { history += code; } - const std::string& GetHistory() const { return history; } - void ReplaceHistory(std::string new_h) { history = std::move(new_h); } + const std::string& GetHistory() const { return history; } + void ReplaceHistory(std::string new_h) { history = std::move(new_h); } - // Sets the root of the analyzer tree as well as the primary PIA. - void SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia); - packet_analysis::IP::SessionAdapter* GetSessionAdapter() { return adapter; } - analyzer::pia::PIA* GetPrimaryPIA() { return primary_PIA; } + // Sets the root of the analyzer tree as well as the primary PIA. + void SetSessionAdapter(packet_analysis::IP::SessionAdapter* aa, analyzer::pia::PIA* pia); + packet_analysis::IP::SessionAdapter* GetSessionAdapter() { return adapter; } + analyzer::pia::PIA* GetPrimaryPIA() { return primary_PIA; } - // Sets the transport protocol in use. - void SetTransport(TransportProto arg_proto) { proto = arg_proto; } + // Sets the transport protocol in use. + void SetTransport(TransportProto arg_proto) { proto = arg_proto; } - void SetUID(const UID& arg_uid) { uid = arg_uid; } + void SetUID(const UID& arg_uid) { uid = arg_uid; } - UID GetUID() const { return uid; } + UID GetUID() const { return uid; } - std::shared_ptr GetEncapsulation() const { return encapsulation; } + std::shared_ptr GetEncapsulation() const { return encapsulation; } - void CheckFlowLabel(bool is_orig, uint32_t flow_label); + void CheckFlowLabel(bool is_orig, uint32_t flow_label); - uint32_t GetOrigFlowLabel() { return orig_flow_label; } - uint32_t GetRespFlowLabel() { return resp_flow_label; } + uint32_t GetOrigFlowLabel() { return orig_flow_label; } + uint32_t GetRespFlowLabel() { return resp_flow_label; } - bool PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration); + bool PermitWeird(const char* name, uint64_t threshold, uint64_t rate, double duration); private: - friend class session::detail::Timer; + friend class session::detail::Timer; - IPAddr orig_addr; - IPAddr resp_addr; - uint32_t orig_port, resp_port; // in network order - TransportProto proto; - uint32_t orig_flow_label, resp_flow_label; // most recent IPv6 flow labels - uint32_t vlan, inner_vlan; // VLAN this connection traverses, if available - u_char orig_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer originator address, if available - u_char resp_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer responder address, if available - int suppress_event; // suppress certain events to once per conn. - RecordValPtr conn_val; - std::shared_ptr encapsulation; // tunnels - uint8_t tunnel_changes = 0; + IPAddr orig_addr; + IPAddr resp_addr; + uint32_t orig_port, resp_port; // in network order + TransportProto proto; + uint32_t orig_flow_label, resp_flow_label; // most recent IPv6 flow labels + uint32_t vlan, inner_vlan; // VLAN this connection traverses, if available + u_char orig_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer originator address, if available + u_char resp_l2_addr[Packet::L2_ADDR_LEN]; // Link-layer responder address, if available + int suppress_event; // suppress certain events to once per conn. + RecordValPtr conn_val; + std::shared_ptr encapsulation; // tunnels + uint8_t tunnel_changes = 0; - detail::ConnKey key; + detail::ConnKey key; - unsigned int weird : 1; - unsigned int finished : 1; - unsigned int saw_first_orig_packet : 1, saw_first_resp_packet : 1; + unsigned int weird : 1; + unsigned int finished : 1; + unsigned int saw_first_orig_packet : 1, saw_first_resp_packet : 1; - uint32_t hist_seen; - std::string history; + uint32_t hist_seen; + std::string history; - packet_analysis::IP::SessionAdapter* adapter; - analyzer::pia::PIA* primary_PIA; + packet_analysis::IP::SessionAdapter* adapter; + analyzer::pia::PIA* primary_PIA; - UID uid; // Globally unique connection ID. - detail::WeirdStateMap weird_state; + UID uid; // Globally unique connection ID. + detail::WeirdStateMap weird_state; - // Count number of connections. - static uint64_t total_connections; - static uint64_t current_connections; - }; + // Count number of connections. + static uint64_t total_connections; + static uint64_t current_connections; +}; - } // namespace zeek +} // namespace zeek diff --git a/src/DFA.cc b/src/DFA.cc index c0ed843281..b9d62f8db4 100644 --- a/src/DFA.cc +++ b/src/DFA.cc @@ -8,455 +8,390 @@ #include "zeek/EquivClass.h" #include "zeek/Hash.h" -namespace zeek::detail - { +namespace zeek::detail { DFA_State::DFA_State(int arg_state_num, const EquivClass* ec, NFA_state_list* arg_nfa_states, - AcceptingSet* arg_accept) - { - state_num = arg_state_num; - num_sym = ec->NumClasses(); - nfa_states = arg_nfa_states; - accept = arg_accept; - mark = nullptr; - - SymPartition(ec); - - xtions = new DFA_State*[num_sym]; - - for ( int i = 0; i < num_sym; ++i ) - xtions[i] = DFA_UNCOMPUTED_STATE_PTR; - } - -DFA_State::~DFA_State() - { - delete[] xtions; - delete nfa_states; - delete accept; - delete meta_ec; - } - -void DFA_State::AddXtion(int sym, DFA_State* next_state) - { - xtions[sym] = next_state; - } - -void DFA_State::SymPartition(const EquivClass* ec) - { - // Partitioning is done by creating equivalence classes for those - // characters which have out-transitions from the given state. Thus - // we are really creating equivalence classes of equivalence classes. - meta_ec = new EquivClass(ec->NumClasses()); - - assert(nfa_states); - for ( int i = 0; i < nfa_states->length(); ++i ) - { - NFA_State* n = (*nfa_states)[i]; - int sym = n->TransSym(); - - if ( sym == SYM_EPSILON ) - continue; - - if ( sym != SYM_CCL ) - { // character transition - if ( ec->IsRep(sym) ) - { - sym = ec->SymEquivClass(sym); - meta_ec->UniqueChar(sym); - } - continue; - } - - // Character class. - meta_ec->CCL_Use(n->TransCCL()); - } - - meta_ec->BuildECs(); - } - -DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) - { - int equiv_sym = meta_ec->EquivRep(sym); - if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) - { - AddXtion(sym, xtions[equiv_sym]); - return xtions[sym]; - } - - const EquivClass* ec = machine->EC(); - - DFA_State* next_d; - - NFA_state_list* ns = SymFollowSet(equiv_sym, ec); - if ( ns->length() > 0 ) - { - NFA_state_list* state_set = epsilon_closure(ns); - if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) ) - delete state_set; - } - else - { - delete ns; - next_d = nullptr; // Jam - } - - AddXtion(equiv_sym, next_d); - if ( sym != equiv_sym ) - AddXtion(sym, next_d); - - return xtions[sym]; - } - -void DFA_State::AppendIfNew(int sym, int_list* sym_list) - { - for ( auto value : *sym_list ) - if ( value == sym ) - return; - - sym_list->push_back(sym); - } - -NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) - { - NFA_state_list* ns = new NFA_state_list; - - assert(nfa_states); - - for ( int i = 0; i < nfa_states->length(); ++i ) - { - NFA_State* n = (*nfa_states)[i]; - - if ( n->TransSym() == SYM_CCL ) - { // it's a character class - CCL* ccl = n->TransCCL(); - int_list* syms = ccl->Syms(); - - if ( ccl->IsNegated() ) - { - size_t j; - for ( j = 0; j < syms->size(); ++j ) - { - // Loop through (sorted) negated - // character class, which has - // presumably already been converted - // over to equivalence classes. - if ( (*syms)[j] >= ec_sym ) - break; - } - - if ( j >= syms->size() || (*syms)[j] > ec_sym ) - // Didn't find ec_sym in ccl. - n->AddXtionsTo(ns); - - continue; - } - - for ( auto sym : *syms ) - { - if ( sym > ec_sym ) - break; - - if ( sym == ec_sym ) - { - n->AddXtionsTo(ns); - break; - } - } - } - - else if ( n->TransSym() == SYM_EPSILON ) - { // do nothing - } - - else if ( ec->IsRep(n->TransSym()) ) - { - if ( ec_sym == ec->SymEquivClass(n->TransSym()) ) - n->AddXtionsTo(ns); - } - } - - ns->resize(0); - return ns; - } - -void DFA_State::ClearMarks() - { - if ( mark ) - { - SetMark(nullptr); - - for ( int i = 0; i < num_sym; ++i ) - { - DFA_State* s = xtions[i]; - - if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) - xtions[i]->ClearMarks(); - } - } - } - -void DFA_State::Describe(ODesc* d) const - { - d->Add("DFA state"); - } - -void DFA_State::Dump(FILE* f, DFA_Machine* m) - { - if ( mark ) - return; - - fprintf(f, "\nDFA state %d:", StateNum()); - - if ( accept ) - { - AcceptingSet::const_iterator it; - - for ( it = accept->begin(); it != accept->end(); ++it ) - fprintf(f, "%s accept #%d", it == accept->begin() ? "" : ",", *it); - } - - fprintf(f, "\n"); - - int num_trans = 0; - for ( int sym = 0; sym < num_sym; ++sym ) - { - DFA_State* s = xtions[sym]; - - if ( ! s ) - continue; - - // Look ahead for compression. - int i; - for ( i = sym + 1; i < num_sym; ++i ) - if ( xtions[i] != s ) - break; - - constexpr int xbuf_size = 512; - char* xbuf = new char[xbuf_size]; - - int r = m->Rep(sym); - if ( ! r ) - r = '.'; - - if ( i == sym + 1 ) - snprintf(xbuf, xbuf_size, "'%c'", r); - else - snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1)); - - if ( s == DFA_UNCOMPUTED_STATE_PTR ) - fprintf(f, "%stransition on %s to ", ++num_trans == 1 ? "\t" : "\n\t", - xbuf); - else - fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, - s->StateNum()); - - delete[] xbuf; - - sym = i - 1; - } - - if ( num_trans > 0 ) - fprintf(f, "\n"); - - SetMark(this); - - for ( int sym = 0; sym < num_sym; ++sym ) - { - DFA_State* s = xtions[sym]; - - if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) - s->Dump(f, m); - } - } - -void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed) - { - for ( int sym = 0; sym < num_sym; ++sym ) - { - DFA_State* s = xtions[sym]; - - if ( s == DFA_UNCOMPUTED_STATE_PTR ) - (*uncomputed)++; - else - (*computed)++; - } - } - -unsigned int DFA_State::Size() - { - return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) + - (accept ? util::pad_size(sizeof(int) * accept->size()) : 0) + - (nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) + - (meta_ec ? meta_ec->Size() : 0); - } - -DFA_State_Cache::DFA_State_Cache() - { - hits = misses = 0; - } - -DFA_State_Cache::~DFA_State_Cache() - { - for ( auto& entry : states ) - { - assert(entry.second); - Unref(entry.second); - } - - states.clear(); - } - -DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest) - { - // We assume that state ID's don't exceed 10 digits, plus - // we allow one more character for the delimiter. - auto id_tag_buf = std::make_unique(nfas.length() * 11 + 1); - auto id_tag = id_tag_buf.get(); - u_char* p = id_tag; - - for ( int i = 0; i < nfas.length(); ++i ) - { - NFA_State* n = nfas[i]; - if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) - { - int id = n->ID(); - do - { - *p++ = '0' + (char)(id % 10); - id /= 10; - } while ( id > 0 ); - *p++ = '&'; - } - } - - *p++ = '\0'; - - // We use the short MD5 instead of the full string for the - // HashKey because the data is copied into the key. - hash128_t hash; - KeyedHash::Hash128(id_tag, p - id_tag, &hash); - *digest = DigestStr(reinterpret_cast(hash), 16); - - auto entry = states.find(*digest); - if ( entry == states.end() ) - { - ++misses; - return nullptr; - } - ++hits; - - digest->clear(); - - return entry->second; - } - -DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) - { - states.emplace(std::move(digest), state); - return state; - } - -void DFA_State_Cache::GetStats(Stats* s) - { - s->dfa_states = 0; - s->nfa_states = 0; - s->computed = 0; - s->uncomputed = 0; - s->mem = 0; - s->hits = hits; - s->misses = misses; - - for ( const auto& state : states ) - { - DFA_State* e = state.second; - ++s->dfa_states; - s->nfa_states += e->NFAStateNum(); - e->Stats(&s->computed, &s->uncomputed); - s->mem += util::pad_size(e->Size()) + padded_sizeof(*e); - } - } - -DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) - { - state_count = 0; - - nfa = n; - Ref(n); - - ec = arg_ec; - - dfa_state_cache = new DFA_State_Cache(); - - NFA_state_list* ns = new NFA_state_list; - ns->push_back(n->FirstState()); - - if ( ns->length() > 0 ) - { - NFA_state_list* state_set = epsilon_closure(ns); - StateSetToDFA_State(state_set, start_state, ec); - } - else - { - start_state = nullptr; // Jam - delete ns; - } - } - -DFA_Machine::~DFA_Machine() - { - delete dfa_state_cache; - Unref(nfa); - } - -void DFA_Machine::Describe(ODesc* d) const - { - d->Add("DFA machine"); - } - -void DFA_Machine::Dump(FILE* f) - { - start_state->Dump(f, this); - start_state->ClearMarks(); - } - -bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, - const EquivClass* ec) - { - DigestStr digest; - d = dfa_state_cache->Lookup(*state_set, &digest); - - if ( d ) - return false; - - AcceptingSet* accept = new AcceptingSet; - - for ( int i = 0; i < state_set->length(); ++i ) - { - int acc = (*state_set)[i]->Accept(); - - if ( acc != NO_ACCEPT ) - accept->insert(acc); - } - - if ( accept->empty() ) - { - delete accept; - accept = nullptr; - } - - DFA_State* ds = new DFA_State(state_count++, ec, state_set, accept); - d = dfa_state_cache->Insert(ds, std::move(digest)); - - return true; - } - -int DFA_Machine::Rep(int sym) - { - for ( int i = 0; i < NUM_SYM; ++i ) - if ( ec->SymEquivClass(i) == sym ) - return i; - - return -1; - } - - } // namespace zeek::detail + AcceptingSet* arg_accept) { + state_num = arg_state_num; + num_sym = ec->NumClasses(); + nfa_states = arg_nfa_states; + accept = arg_accept; + mark = nullptr; + + SymPartition(ec); + + xtions = new DFA_State*[num_sym]; + + for ( int i = 0; i < num_sym; ++i ) + xtions[i] = DFA_UNCOMPUTED_STATE_PTR; +} + +DFA_State::~DFA_State() { + delete[] xtions; + delete nfa_states; + delete accept; + delete meta_ec; +} + +void DFA_State::AddXtion(int sym, DFA_State* next_state) { xtions[sym] = next_state; } + +void DFA_State::SymPartition(const EquivClass* ec) { + // Partitioning is done by creating equivalence classes for those + // characters which have out-transitions from the given state. Thus + // we are really creating equivalence classes of equivalence classes. + meta_ec = new EquivClass(ec->NumClasses()); + + assert(nfa_states); + for ( int i = 0; i < nfa_states->length(); ++i ) { + NFA_State* n = (*nfa_states)[i]; + int sym = n->TransSym(); + + if ( sym == SYM_EPSILON ) + continue; + + if ( sym != SYM_CCL ) { // character transition + if ( ec->IsRep(sym) ) { + sym = ec->SymEquivClass(sym); + meta_ec->UniqueChar(sym); + } + continue; + } + + // Character class. + meta_ec->CCL_Use(n->TransCCL()); + } + + meta_ec->BuildECs(); +} + +DFA_State* DFA_State::ComputeXtion(int sym, DFA_Machine* machine) { + int equiv_sym = meta_ec->EquivRep(sym); + if ( xtions[equiv_sym] != DFA_UNCOMPUTED_STATE_PTR ) { + AddXtion(sym, xtions[equiv_sym]); + return xtions[sym]; + } + + const EquivClass* ec = machine->EC(); + + DFA_State* next_d; + + NFA_state_list* ns = SymFollowSet(equiv_sym, ec); + if ( ns->length() > 0 ) { + NFA_state_list* state_set = epsilon_closure(ns); + if ( ! machine->StateSetToDFA_State(state_set, next_d, ec) ) + delete state_set; + } + else { + delete ns; + next_d = nullptr; // Jam + } + + AddXtion(equiv_sym, next_d); + if ( sym != equiv_sym ) + AddXtion(sym, next_d); + + return xtions[sym]; +} + +void DFA_State::AppendIfNew(int sym, int_list* sym_list) { + for ( auto value : *sym_list ) + if ( value == sym ) + return; + + sym_list->push_back(sym); +} + +NFA_state_list* DFA_State::SymFollowSet(int ec_sym, const EquivClass* ec) { + NFA_state_list* ns = new NFA_state_list; + + assert(nfa_states); + + for ( int i = 0; i < nfa_states->length(); ++i ) { + NFA_State* n = (*nfa_states)[i]; + + if ( n->TransSym() == SYM_CCL ) { // it's a character class + CCL* ccl = n->TransCCL(); + int_list* syms = ccl->Syms(); + + if ( ccl->IsNegated() ) { + size_t j; + for ( j = 0; j < syms->size(); ++j ) { + // Loop through (sorted) negated + // character class, which has + // presumably already been converted + // over to equivalence classes. + if ( (*syms)[j] >= ec_sym ) + break; + } + + if ( j >= syms->size() || (*syms)[j] > ec_sym ) + // Didn't find ec_sym in ccl. + n->AddXtionsTo(ns); + + continue; + } + + for ( auto sym : *syms ) { + if ( sym > ec_sym ) + break; + + if ( sym == ec_sym ) { + n->AddXtionsTo(ns); + break; + } + } + } + + else if ( n->TransSym() == SYM_EPSILON ) { // do nothing + } + + else if ( ec->IsRep(n->TransSym()) ) { + if ( ec_sym == ec->SymEquivClass(n->TransSym()) ) + n->AddXtionsTo(ns); + } + } + + ns->resize(0); + return ns; +} + +void DFA_State::ClearMarks() { + if ( mark ) { + SetMark(nullptr); + + for ( int i = 0; i < num_sym; ++i ) { + DFA_State* s = xtions[i]; + + if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) + xtions[i]->ClearMarks(); + } + } +} + +void DFA_State::Describe(ODesc* d) const { d->Add("DFA state"); } + +void DFA_State::Dump(FILE* f, DFA_Machine* m) { + if ( mark ) + return; + + fprintf(f, "\nDFA state %d:", StateNum()); + + if ( accept ) { + AcceptingSet::const_iterator it; + + for ( it = accept->begin(); it != accept->end(); ++it ) + fprintf(f, "%s accept #%d", it == accept->begin() ? "" : ",", *it); + } + + fprintf(f, "\n"); + + int num_trans = 0; + for ( int sym = 0; sym < num_sym; ++sym ) { + DFA_State* s = xtions[sym]; + + if ( ! s ) + continue; + + // Look ahead for compression. + int i; + for ( i = sym + 1; i < num_sym; ++i ) + if ( xtions[i] != s ) + break; + + constexpr int xbuf_size = 512; + char* xbuf = new char[xbuf_size]; + + int r = m->Rep(sym); + if ( ! r ) + r = '.'; + + if ( i == sym + 1 ) + snprintf(xbuf, xbuf_size, "'%c'", r); + else + snprintf(xbuf, xbuf_size, "'%c'-'%c'", r, m->Rep(i - 1)); + + if ( s == DFA_UNCOMPUTED_STATE_PTR ) + fprintf(f, "%stransition on %s to ", ++num_trans == 1 ? "\t" : "\n\t", xbuf); + else + fprintf(f, "%stransition on %s to state %d", ++num_trans == 1 ? "\t" : "\n\t", xbuf, s->StateNum()); + + delete[] xbuf; + + sym = i - 1; + } + + if ( num_trans > 0 ) + fprintf(f, "\n"); + + SetMark(this); + + for ( int sym = 0; sym < num_sym; ++sym ) { + DFA_State* s = xtions[sym]; + + if ( s && s != DFA_UNCOMPUTED_STATE_PTR ) + s->Dump(f, m); + } +} + +void DFA_State::Stats(unsigned int* computed, unsigned int* uncomputed) { + for ( int sym = 0; sym < num_sym; ++sym ) { + DFA_State* s = xtions[sym]; + + if ( s == DFA_UNCOMPUTED_STATE_PTR ) + (*uncomputed)++; + else + (*computed)++; + } +} + +unsigned int DFA_State::Size() { + return sizeof(*this) + util::pad_size(sizeof(DFA_State*) * num_sym) + + (accept ? util::pad_size(sizeof(int) * accept->size()) : 0) + + (nfa_states ? util::pad_size(sizeof(NFA_State*) * nfa_states->length()) : 0) + + (meta_ec ? meta_ec->Size() : 0); +} + +DFA_State_Cache::DFA_State_Cache() { hits = misses = 0; } + +DFA_State_Cache::~DFA_State_Cache() { + for ( auto& entry : states ) { + assert(entry.second); + Unref(entry.second); + } + + states.clear(); +} + +DFA_State* DFA_State_Cache::Lookup(const NFA_state_list& nfas, DigestStr* digest) { + // We assume that state ID's don't exceed 10 digits, plus + // we allow one more character for the delimiter. + auto id_tag_buf = std::make_unique(nfas.length() * 11 + 1); + auto id_tag = id_tag_buf.get(); + u_char* p = id_tag; + + for ( int i = 0; i < nfas.length(); ++i ) { + NFA_State* n = nfas[i]; + if ( n->TransSym() != SYM_EPSILON || n->Accept() != NO_ACCEPT ) { + int id = n->ID(); + do { + *p++ = '0' + (char)(id % 10); + id /= 10; + } while ( id > 0 ); + *p++ = '&'; + } + } + + *p++ = '\0'; + + // We use the short MD5 instead of the full string for the + // HashKey because the data is copied into the key. + hash128_t hash; + KeyedHash::Hash128(id_tag, p - id_tag, &hash); + *digest = DigestStr(reinterpret_cast(hash), 16); + + auto entry = states.find(*digest); + if ( entry == states.end() ) { + ++misses; + return nullptr; + } + ++hits; + + digest->clear(); + + return entry->second; +} + +DFA_State* DFA_State_Cache::Insert(DFA_State* state, DigestStr digest) { + states.emplace(std::move(digest), state); + return state; +} + +void DFA_State_Cache::GetStats(Stats* s) { + s->dfa_states = 0; + s->nfa_states = 0; + s->computed = 0; + s->uncomputed = 0; + s->mem = 0; + s->hits = hits; + s->misses = misses; + + for ( const auto& state : states ) { + DFA_State* e = state.second; + ++s->dfa_states; + s->nfa_states += e->NFAStateNum(); + e->Stats(&s->computed, &s->uncomputed); + s->mem += util::pad_size(e->Size()) + padded_sizeof(*e); + } +} + +DFA_Machine::DFA_Machine(NFA_Machine* n, EquivClass* arg_ec) { + state_count = 0; + + nfa = n; + Ref(n); + + ec = arg_ec; + + dfa_state_cache = new DFA_State_Cache(); + + NFA_state_list* ns = new NFA_state_list; + ns->push_back(n->FirstState()); + + if ( ns->length() > 0 ) { + NFA_state_list* state_set = epsilon_closure(ns); + StateSetToDFA_State(state_set, start_state, ec); + } + else { + start_state = nullptr; // Jam + delete ns; + } +} + +DFA_Machine::~DFA_Machine() { + delete dfa_state_cache; + Unref(nfa); +} + +void DFA_Machine::Describe(ODesc* d) const { d->Add("DFA machine"); } + +void DFA_Machine::Dump(FILE* f) { + start_state->Dump(f, this); + start_state->ClearMarks(); +} + +bool DFA_Machine::StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec) { + DigestStr digest; + d = dfa_state_cache->Lookup(*state_set, &digest); + + if ( d ) + return false; + + AcceptingSet* accept = new AcceptingSet; + + for ( int i = 0; i < state_set->length(); ++i ) { + int acc = (*state_set)[i]->Accept(); + + if ( acc != NO_ACCEPT ) + accept->insert(acc); + } + + if ( accept->empty() ) { + delete accept; + accept = nullptr; + } + + DFA_State* ds = new DFA_State(state_count++, ec, state_set, accept); + d = dfa_state_cache->Insert(ds, std::move(digest)); + + return true; +} + +int DFA_Machine::Rep(int sym) { + for ( int i = 0; i < NUM_SYM; ++i ) + if ( ec->SymEquivClass(i) == sym ) + return i; + + return -1; +} + +} // namespace zeek::detail diff --git a/src/DFA.h b/src/DFA.h index 1af46d647a..fd0b590025 100644 --- a/src/DFA.h +++ b/src/DFA.h @@ -11,8 +11,7 @@ #include "zeek/Obj.h" #include "zeek/RE.h" // for typedef AcceptingSet -namespace zeek::detail - { +namespace zeek::detail { class DFA_State; class DFA_Machine; @@ -22,132 +21,126 @@ class DFA_Machine; #define DFA_UNCOMPUTED_STATE -2 #define DFA_UNCOMPUTED_STATE_PTR ((DFA_State*)DFA_UNCOMPUTED_STATE) -class DFA_State : public Obj - { +class DFA_State : public Obj { public: - DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, - AcceptingSet* accept); - ~DFA_State() override; + DFA_State(int state_num, const EquivClass* ec, NFA_state_list* nfa_states, AcceptingSet* accept); + ~DFA_State() override; - int StateNum() const { return state_num; } - int NFAStateNum() const { return nfa_states->length(); } - void AddXtion(int sym, DFA_State* next_state); + int StateNum() const { return state_num; } + int NFAStateNum() const { return nfa_states->length(); } + void AddXtion(int sym, DFA_State* next_state); - inline DFA_State* Xtion(int sym, DFA_Machine* machine); + inline DFA_State* Xtion(int sym, DFA_Machine* machine); - const AcceptingSet* Accept() const { return accept; } - void SymPartition(const EquivClass* ec); + const AcceptingSet* Accept() const { return accept; } + void SymPartition(const EquivClass* ec); - // ec_sym is an equivalence class, not a character. - NFA_state_list* SymFollowSet(int ec_sym, const EquivClass* ec); + // ec_sym is an equivalence class, not a character. + NFA_state_list* SymFollowSet(int ec_sym, const EquivClass* ec); - void SetMark(DFA_State* m) { mark = m; } - DFA_State* Mark() const { return mark; } - void ClearMarks(); + void SetMark(DFA_State* m) { mark = m; } + DFA_State* Mark() const { return mark; } + void ClearMarks(); - // Returns the equivalence classes of ec's corresponding to this state. - const EquivClass* MetaECs() const { return meta_ec; } + // Returns the equivalence classes of ec's corresponding to this state. + const EquivClass* MetaECs() const { return meta_ec; } - void Describe(ODesc* d) const override; - void Dump(FILE* f, DFA_Machine* m); - void Stats(unsigned int* computed, unsigned int* uncomputed); - unsigned int Size(); + void Describe(ODesc* d) const override; + void Dump(FILE* f, DFA_Machine* m); + void Stats(unsigned int* computed, unsigned int* uncomputed); + unsigned int Size(); protected: - friend class DFA_State_Cache; + friend class DFA_State_Cache; - DFA_State* ComputeXtion(int sym, DFA_Machine* machine); - void AppendIfNew(int sym, int_list* sym_list); + DFA_State* ComputeXtion(int sym, DFA_Machine* machine); + void AppendIfNew(int sym, int_list* sym_list); - int state_num; - int num_sym; + int state_num; + int num_sym; - DFA_State** xtions; + DFA_State** xtions; - AcceptingSet* accept; - NFA_state_list* nfa_states; - EquivClass* meta_ec; // which ec's make same transition - DFA_State* mark; - }; + AcceptingSet* accept; + NFA_state_list* nfa_states; + EquivClass* meta_ec; // which ec's make same transition + DFA_State* mark; +}; using DigestStr = std::basic_string; -class DFA_State_Cache - { +class DFA_State_Cache { public: - DFA_State_Cache(); - ~DFA_State_Cache(); + DFA_State_Cache(); + ~DFA_State_Cache(); - // If the caller stores the handle, it has to call Ref() on it. - DFA_State* Lookup(const NFA_state_list& nfa_states, DigestStr* digest); + // If the caller stores the handle, it has to call Ref() on it. + DFA_State* Lookup(const NFA_state_list& nfa_states, DigestStr* digest); - // Takes ownership of state; digest is the one returned by Lookup(). - DFA_State* Insert(DFA_State* state, DigestStr digest); + // Takes ownership of state; digest is the one returned by Lookup(). + DFA_State* Insert(DFA_State* state, DigestStr digest); - int NumEntries() const { return states.size(); } + int NumEntries() const { return states.size(); } - struct Stats - { - // Sum of all NFA states - unsigned int nfa_states; - unsigned int dfa_states; - unsigned int computed; - unsigned int uncomputed; - unsigned int mem; - unsigned int hits; - unsigned int misses; - }; + struct Stats { + // Sum of all NFA states + unsigned int nfa_states; + unsigned int dfa_states; + unsigned int computed; + unsigned int uncomputed; + unsigned int mem; + unsigned int hits; + unsigned int misses; + }; - void GetStats(Stats* s); + void GetStats(Stats* s); private: - int hits; // Statistics - int misses; + int hits; // Statistics + int misses; - // Hash indexed by NFA states (MD5s of them, actually). - std::map states; - }; + // Hash indexed by NFA states (MD5s of them, actually). + std::map states; +}; -class DFA_Machine : public Obj - { +class DFA_Machine : public Obj { public: - DFA_Machine(NFA_Machine* n, EquivClass* ec); - ~DFA_Machine() override; + DFA_Machine(NFA_Machine* n, EquivClass* ec); + ~DFA_Machine() override; - DFA_State* StartState() const { return start_state; } + DFA_State* StartState() const { return start_state; } - int NumStates() const { return dfa_state_cache->NumEntries(); } + int NumStates() const { return dfa_state_cache->NumEntries(); } - DFA_State_Cache* Cache() { return dfa_state_cache; } + DFA_State_Cache* Cache() { return dfa_state_cache; } - int Rep(int sym); + int Rep(int sym); - void Describe(ODesc* d) const override; - void Dump(FILE* f); + void Describe(ODesc* d) const override; + void Dump(FILE* f); protected: - friend class DFA_State; // for DFA_State::ComputeXtion - friend class DFA_State_Cache; + friend class DFA_State; // for DFA_State::ComputeXtion + friend class DFA_State_Cache; - int state_count; + int state_count; - // The state list has to be sorted according to IDs. - bool StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec); - const EquivClass* EC() const { return ec; } + // The state list has to be sorted according to IDs. + bool StateSetToDFA_State(NFA_state_list* state_set, DFA_State*& d, const EquivClass* ec); + const EquivClass* EC() const { return ec; } - EquivClass* ec; // equivalence classes corresponding to NFAs - DFA_State* start_state; - DFA_State_Cache* dfa_state_cache; + EquivClass* ec; // equivalence classes corresponding to NFAs + DFA_State* start_state; + DFA_State_Cache* dfa_state_cache; - NFA_Machine* nfa; - }; + NFA_Machine* nfa; +}; -inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) - { - if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR ) - return ComputeXtion(sym, machine); - else - return xtions[sym]; - } +inline DFA_State* DFA_State::Xtion(int sym, DFA_Machine* machine) { + if ( xtions[sym] == DFA_UNCOMPUTED_STATE_PTR ) + return ComputeXtion(sym, machine); + else + return xtions[sym]; +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DNS_Mapping.cc b/src/DNS_Mapping.cc index b2f6c222f2..6e2720f4b0 100644 --- a/src/DNS_Mapping.cc +++ b/src/DNS_Mapping.cc @@ -6,428 +6,399 @@ #include "zeek/DNS_Mgr.h" #include "zeek/Reporter.h" -namespace zeek::detail - { +namespace zeek::detail { -DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) - { - Init(h); - req_host = host; - req_ttl = ttl; - req_type = type; +DNS_Mapping::DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type) { + Init(h); + req_host = host; + req_ttl = ttl; + req_type = type; - if ( names.empty() ) - names.push_back(std::move(host)); - } + if ( names.empty() ) + names.push_back(std::move(host)); +} -DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) - { - Init(h); - req_addr = addr; - req_ttl = ttl; - req_type = T_PTR; - } +DNS_Mapping::DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl) { + Init(h); + req_addr = addr; + req_ttl = ttl; + req_type = T_PTR; +} -DNS_Mapping::DNS_Mapping(FILE* f) - { - Clear(); - init_failed = true; +DNS_Mapping::DNS_Mapping(FILE* f) { + Clear(); + init_failed = true; - req_ttl = 0; - creation_time = 0; + req_ttl = 0; + creation_time = 0; - char buf[512]; + char buf[512]; - if ( ! fgets(buf, sizeof(buf), f) ) - { - no_mapping = true; - return; - } + if ( ! fgets(buf, sizeof(buf), f) ) { + no_mapping = true; + return; + } - char req_buf[512 + 1], name_buf[512 + 1]; - int is_req_host; - int failed_local; - int num_addrs; + char req_buf[512 + 1], name_buf[512 + 1]; + int is_req_host; + int failed_local; + int num_addrs; - if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, - &failed_local, name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) - { - no_mapping = true; - return; - } + if ( sscanf(buf, "%lf %d %512s %d %512s %d %d %" PRIu32, &creation_time, &is_req_host, req_buf, &failed_local, + name_buf, &req_type, &num_addrs, &req_ttl) != 8 ) { + no_mapping = true; + return; + } - failed = static_cast(failed_local); + failed = static_cast(failed_local); - if ( is_req_host ) - req_host = req_buf; - else - req_addr = IPAddr(req_buf); + if ( is_req_host ) + req_host = req_buf; + else + req_addr = IPAddr(req_buf); - names.emplace_back(name_buf); + names.emplace_back(name_buf); - for ( int i = 0; i < num_addrs; ++i ) - { - if ( ! fgets(buf, sizeof(buf), f) ) - return; + for ( int i = 0; i < num_addrs; ++i ) { + if ( ! fgets(buf, sizeof(buf), f) ) + return; - char* newline = strchr(buf, '\n'); - if ( newline ) - *newline = '\0'; + char* newline = strchr(buf, '\n'); + if ( newline ) + *newline = '\0'; - addrs.emplace_back(buf); - } + addrs.emplace_back(buf); + } - init_failed = false; - } + init_failed = false; +} -ListValPtr DNS_Mapping::Addrs() - { - if ( failed ) - return nullptr; +ListValPtr DNS_Mapping::Addrs() { + if ( failed ) + return nullptr; - if ( ! addrs_val ) - { - addrs_val = make_intrusive(TYPE_ADDR); + if ( ! addrs_val ) { + addrs_val = make_intrusive(TYPE_ADDR); - for ( const auto& addr : addrs ) - addrs_val->Append(make_intrusive(addr)); - } + for ( const auto& addr : addrs ) + addrs_val->Append(make_intrusive(addr)); + } - return addrs_val; - } + return addrs_val; +} -TableValPtr DNS_Mapping::AddrsSet() - { - auto l = Addrs(); +TableValPtr DNS_Mapping::AddrsSet() { + auto l = Addrs(); - if ( ! l || l->Length() == 0 ) - return DNS_Mgr::empty_addr_set(); + if ( ! l || l->Length() == 0 ) + return DNS_Mgr::empty_addr_set(); - return l->ToSetVal(); - } + return l->ToSetVal(); +} -StringValPtr DNS_Mapping::Host() - { - if ( failed || names.empty() ) - return nullptr; +StringValPtr DNS_Mapping::Host() { + if ( failed || names.empty() ) + return nullptr; - if ( ! host_val ) - host_val = make_intrusive(names[0]); + if ( ! host_val ) + host_val = make_intrusive(names[0]); - return host_val; - } + return host_val; +} -void DNS_Mapping::Init(struct hostent* h) - { - no_mapping = false; - init_failed = false; - creation_time = util::current_time(); - host_val = nullptr; - addrs_val = nullptr; +void DNS_Mapping::Init(struct hostent* h) { + no_mapping = false; + init_failed = false; + creation_time = util::current_time(); + host_val = nullptr; + addrs_val = nullptr; - if ( ! h ) - { - Clear(); - return; - } + if ( ! h ) { + Clear(); + return; + } - if ( h->h_name ) - // for now, just use the official name - // TODO: this could easily be expanded to include all of the aliases as well - names.emplace_back(h->h_name); + if ( h->h_name ) + // for now, just use the official name + // TODO: this could easily be expanded to include all of the aliases as well + names.emplace_back(h->h_name); - if ( h->h_addr_list ) - { - for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) - { - if ( h->h_addrtype == AF_INET ) - addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network); - else if ( h->h_addrtype == AF_INET6 ) - addrs.emplace_back(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network); - } - } + if ( h->h_addr_list ) { + for ( int i = 0; h->h_addr_list[i] != NULL; ++i ) { + if ( h->h_addrtype == AF_INET ) + addrs.emplace_back(IPv4, (uint32_t*)h->h_addr_list[i], IPAddr::Network); + else if ( h->h_addrtype == AF_INET6 ) + addrs.emplace_back(IPv6, (uint32_t*)h->h_addr_list[i], IPAddr::Network); + } + } - failed = false; - } + failed = false; +} -void DNS_Mapping::Clear() - { - names.clear(); - host_val = nullptr; - addrs.clear(); - addrs_val = nullptr; - no_mapping = false; - req_type = 0; - failed = true; - } +void DNS_Mapping::Clear() { + names.clear(); + host_val = nullptr; + addrs.clear(); + addrs_val = nullptr; + no_mapping = false; + req_type = 0; + failed = true; +} -void DNS_Mapping::Save(FILE* f) const - { - fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), - req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, - names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl); +void DNS_Mapping::Save(FILE* f) const { + fprintf(f, "%.0f %d %s %d %s %d %zu %" PRIu32 "\n", creation_time, ! req_host.empty(), + req_host.empty() ? req_addr.AsString().c_str() : req_host.c_str(), failed, + names.empty() ? "*" : names[0].c_str(), req_type, addrs.size(), req_ttl); - for ( const auto& addr : addrs ) - fprintf(f, "%s\n", addr.AsString().c_str()); - } + for ( const auto& addr : addrs ) + fprintf(f, "%s\n", addr.AsString().c_str()); +} -void DNS_Mapping::Merge(const DNS_MappingPtr& other) - { - std::copy(other->names.begin(), other->names.end(), std::back_inserter(names)); - std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs)); - } +void DNS_Mapping::Merge(const DNS_MappingPtr& other) { + std::copy(other->names.begin(), other->names.end(), std::back_inserter(names)); + std::copy(other->addrs.begin(), other->addrs.end(), std::back_inserter(addrs)); +} // This value needs to be incremented if something changes in the data stored by Save(). This // allows us to change the structure of the cache without breaking something in DNS_Mgr. constexpr int FILE_VERSION = 1; -void DNS_Mapping::InitializeCache(FILE* f) - { - fprintf(f, "%d\n", FILE_VERSION); - } +void DNS_Mapping::InitializeCache(FILE* f) { fprintf(f, "%d\n", FILE_VERSION); } -bool DNS_Mapping::ValidateCacheVersion(FILE* f) - { - char buf[512]; - if ( ! fgets(buf, sizeof(buf), f) ) - return false; +bool DNS_Mapping::ValidateCacheVersion(FILE* f) { + char buf[512]; + if ( ! fgets(buf, sizeof(buf), f) ) + return false; - int version; - if ( sscanf(buf, "%d", &version) != 1 ) - { - reporter->Warning("Existing DNS cache did not have correct version, ignoring"); - return false; - } + int version; + if ( sscanf(buf, "%d", &version) != 1 ) { + reporter->Warning("Existing DNS cache did not have correct version, ignoring"); + return false; + } - return FILE_VERSION == version; - } + return FILE_VERSION == version; +} ////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////// -TEST_CASE("dns_mapping init null hostent") - { - DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A); +TEST_CASE("dns_mapping init null hostent") { + DNS_Mapping mapping("www.apple.com", nullptr, 123, T_A); - CHECK(! mapping.Valid()); - CHECK(mapping.Addrs() == nullptr); - CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set())); - CHECK(mapping.Host() == nullptr); - } + CHECK(! mapping.Valid()); + CHECK(mapping.Addrs() == nullptr); + CHECK(mapping.AddrsSet()->EqualTo(DNS_Mgr::empty_addr_set())); + CHECK(mapping.Host() == nullptr); +} -TEST_CASE("dns_mapping init host") - { - IPAddr addr("1.2.3.4"); - in4_addr in4; - addr.CopyIPv4(&in4); +TEST_CASE("dns_mapping init host") { + IPAddr addr("1.2.3.4"); + in4_addr in4; + addr.CopyIPv4(&in4); - struct hostent he; - he.h_name = util::copy_string("testing.home"); - he.h_aliases = NULL; - he.h_addrtype = AF_INET; - he.h_length = sizeof(in_addr); + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); - std::vector addrs = {&in4, NULL}; - he.h_addr_list = reinterpret_cast(addrs.data()); + std::vector addrs = {&in4, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping("testing.home", &he, 123, T_A); - CHECK(mapping.Valid()); - CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified); - CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0); - CHECK(mapping.ReqStr() == "testing.home"); + DNS_Mapping mapping("testing.home", &he, 123, T_A); + CHECK(mapping.Valid()); + CHECK(mapping.ReqAddr() == IPAddr::v6_unspecified); + CHECK(strcmp(mapping.ReqHost(), "testing.home") == 0); + CHECK(mapping.ReqStr() == "testing.home"); - auto lva = mapping.Addrs(); - REQUIRE(lva != nullptr); - CHECK(lva->Length() == 1); - auto lvae = lva->Idx(0)->AsAddrVal(); - REQUIRE(lvae != nullptr); - CHECK(lvae->Get().AsString() == "1.2.3.4"); + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); - auto tvas = mapping.AddrsSet(); - REQUIRE(tvas != nullptr); - CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); + auto tvas = mapping.AddrsSet(); + REQUIRE(tvas != nullptr); + CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); - auto svh = mapping.Host(); - REQUIRE(svh != nullptr); - CHECK(svh->ToStdString() == "testing.home"); + auto svh = mapping.Host(); + REQUIRE(svh != nullptr); + CHECK(svh->ToStdString() == "testing.home"); - delete[] he.h_name; - } + delete[] he.h_name; +} -TEST_CASE("dns_mapping init addr") - { - IPAddr addr("1.2.3.4"); - in4_addr in4; - addr.CopyIPv4(&in4); +TEST_CASE("dns_mapping init addr") { + IPAddr addr("1.2.3.4"); + in4_addr in4; + addr.CopyIPv4(&in4); - struct hostent he; - he.h_name = util::copy_string("testing.home"); - he.h_aliases = NULL; - he.h_addrtype = AF_INET; - he.h_length = sizeof(in_addr); + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); - std::vector addrs = {&in4, NULL}; - he.h_addr_list = reinterpret_cast(addrs.data()); + std::vector addrs = {&in4, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping(addr, &he, 123); - CHECK(mapping.Valid()); - CHECK(mapping.ReqAddr() == addr); - CHECK(mapping.ReqHost() == nullptr); - CHECK(mapping.ReqStr() == "1.2.3.4"); + DNS_Mapping mapping(addr, &he, 123); + CHECK(mapping.Valid()); + CHECK(mapping.ReqAddr() == addr); + CHECK(mapping.ReqHost() == nullptr); + CHECK(mapping.ReqStr() == "1.2.3.4"); - auto lva = mapping.Addrs(); - REQUIRE(lva != nullptr); - CHECK(lva->Length() == 1); - auto lvae = lva->Idx(0)->AsAddrVal(); - REQUIRE(lvae != nullptr); - CHECK(lvae->Get().AsString() == "1.2.3.4"); + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); - auto tvas = mapping.AddrsSet(); - REQUIRE(tvas != nullptr); - CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); + auto tvas = mapping.AddrsSet(); + REQUIRE(tvas != nullptr); + CHECK_FALSE(tvas->EqualTo(DNS_Mgr::empty_addr_set())); - auto svh = mapping.Host(); - REQUIRE(svh != nullptr); - CHECK(svh->ToStdString() == "testing.home"); + auto svh = mapping.Host(); + REQUIRE(svh != nullptr); + CHECK(svh->ToStdString() == "testing.home"); - delete[] he.h_name; - } + delete[] he.h_name; +} -TEST_CASE("dns_mapping save reload") - { - // TODO: this test uses fmemopen and mkdtemp, both of which aren't available on - // Windows. We'll have to figure out another way to do this test there. +TEST_CASE("dns_mapping save reload") { + // TODO: this test uses fmemopen and mkdtemp, both of which aren't available on + // Windows. We'll have to figure out another way to do this test there. #ifndef _MSC_VER - IPAddr addr("1.2.3.4"); - in4_addr in4; - addr.CopyIPv4(&in4); + IPAddr addr("1.2.3.4"); + in4_addr in4; + addr.CopyIPv4(&in4); - struct hostent he; - he.h_name = util::copy_string("testing.home"); - he.h_aliases = NULL; - he.h_addrtype = AF_INET; - he.h_length = sizeof(in_addr); + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); - std::vector addrs = {&in4, NULL}; - he.h_addr_list = reinterpret_cast(addrs.data()); + std::vector addrs = {&in4, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); - // Create a temporary file in memory and fseek to the end of it so we're at - // EOF for the next bit. - char buffer[4096]; - memset(buffer, 0, 4096); - FILE* tmpfile = fmemopen(buffer, 4096, "r+"); - if ( fseek(tmpfile, 0, SEEK_END) < 0 ) - reporter->Error("DNS_Mapping: seek failed"); + // Create a temporary file in memory and fseek to the end of it so we're at + // EOF for the next bit. + char buffer[4096]; + memset(buffer, 0, 4096); + FILE* tmpfile = fmemopen(buffer, 4096, "r+"); + if ( fseek(tmpfile, 0, SEEK_END) < 0 ) + reporter->Error("DNS_Mapping: seek failed"); - // Try loading from the file at EOF. This should cause a mapping failure. - DNS_Mapping mapping(tmpfile); - CHECK(mapping.NoMapping()); - rewind(tmpfile); + // Try loading from the file at EOF. This should cause a mapping failure. + DNS_Mapping mapping(tmpfile); + CHECK(mapping.NoMapping()); + rewind(tmpfile); - // Try reading from the empty file. This should cause an init failure. - DNS_Mapping mapping2(tmpfile); - CHECK(mapping2.InitFailed()); - rewind(tmpfile); + // Try reading from the empty file. This should cause an init failure. + DNS_Mapping mapping2(tmpfile); + CHECK(mapping2.InitFailed()); + rewind(tmpfile); - // Save a valid mapping into the file and rewind to the start. - DNS_Mapping mapping3(addr, &he, 123); - mapping3.Save(tmpfile); - rewind(tmpfile); + // Save a valid mapping into the file and rewind to the start. + DNS_Mapping mapping3(addr, &he, 123); + mapping3.Save(tmpfile); + rewind(tmpfile); - // Test loading the mapping back out of the file - DNS_Mapping mapping4(tmpfile); - fclose(tmpfile); - CHECK(mapping4.Valid()); - CHECK(mapping4.ReqAddr() == addr); - CHECK(mapping4.ReqHost() == nullptr); - CHECK(mapping4.ReqStr() == "1.2.3.4"); + // Test loading the mapping back out of the file + DNS_Mapping mapping4(tmpfile); + fclose(tmpfile); + CHECK(mapping4.Valid()); + CHECK(mapping4.ReqAddr() == addr); + CHECK(mapping4.ReqHost() == nullptr); + CHECK(mapping4.ReqStr() == "1.2.3.4"); - auto lva = mapping4.Addrs(); - REQUIRE(lva != nullptr); - CHECK(lva->Length() == 1); - auto lvae = lva->Idx(0)->AsAddrVal(); - REQUIRE(lvae != nullptr); - CHECK(lvae->Get().AsString() == "1.2.3.4"); + auto lva = mapping4.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); - auto tvas = mapping4.AddrsSet(); - REQUIRE(tvas != nullptr); - CHECK(tvas != DNS_Mgr::empty_addr_set()); + auto tvas = mapping4.AddrsSet(); + REQUIRE(tvas != nullptr); + CHECK(tvas != DNS_Mgr::empty_addr_set()); - auto svh = mapping4.Host(); - REQUIRE(svh != nullptr); - CHECK(svh->ToStdString() == "testing.home"); + auto svh = mapping4.Host(); + REQUIRE(svh != nullptr); + CHECK(svh->ToStdString() == "testing.home"); - delete[] he.h_name; + delete[] he.h_name; #endif - } +} -TEST_CASE("dns_mapping multiple addresses") - { - IPAddr addr("1.2.3.4"); - in4_addr in4_1; - addr.CopyIPv4(&in4_1); +TEST_CASE("dns_mapping multiple addresses") { + IPAddr addr("1.2.3.4"); + in4_addr in4_1; + addr.CopyIPv4(&in4_1); - IPAddr addr2("5.6.7.8"); - in4_addr in4_2; - addr2.CopyIPv4(&in4_2); + IPAddr addr2("5.6.7.8"); + in4_addr in4_2; + addr2.CopyIPv4(&in4_2); - struct hostent he; - he.h_name = util::copy_string("testing.home"); - he.h_aliases = NULL; - he.h_addrtype = AF_INET; - he.h_length = sizeof(in_addr); + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); - std::vector addrs = {&in4_1, &in4_2, NULL}; - he.h_addr_list = reinterpret_cast(addrs.data()); + std::vector addrs = {&in4_1, &in4_2, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping("testing.home", &he, 123, T_A); - CHECK(mapping.Valid()); + DNS_Mapping mapping("testing.home", &he, 123, T_A); + CHECK(mapping.Valid()); - auto lva = mapping.Addrs(); - REQUIRE(lva != nullptr); - CHECK(lva->Length() == 2); + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 2); - auto lvae = lva->Idx(0)->AsAddrVal(); - REQUIRE(lvae != nullptr); - CHECK(lvae->Get().AsString() == "1.2.3.4"); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "1.2.3.4"); - lvae = lva->Idx(1)->AsAddrVal(); - REQUIRE(lvae != nullptr); - CHECK(lvae->Get().AsString() == "5.6.7.8"); + lvae = lva->Idx(1)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "5.6.7.8"); - delete[] he.h_name; - } + delete[] he.h_name; +} -TEST_CASE("dns_mapping ipv6") - { - IPAddr addr("64:ff9b:1::"); - in6_addr in6; - addr.CopyIPv6(&in6); +TEST_CASE("dns_mapping ipv6") { + IPAddr addr("64:ff9b:1::"); + in6_addr in6; + addr.CopyIPv6(&in6); - struct hostent he; - he.h_name = util::copy_string("testing.home"); - he.h_aliases = NULL; - he.h_addrtype = AF_INET6; - he.h_length = sizeof(in6_addr); + struct hostent he; + he.h_name = util::copy_string("testing.home"); + he.h_aliases = NULL; + he.h_addrtype = AF_INET6; + he.h_length = sizeof(in6_addr); - std::vector addrs = {&in6, NULL}; - he.h_addr_list = reinterpret_cast(addrs.data()); + std::vector addrs = {&in6, NULL}; + he.h_addr_list = reinterpret_cast(addrs.data()); - DNS_Mapping mapping(addr, &he, 123); - CHECK(mapping.Valid()); - CHECK(mapping.ReqAddr() == addr); - CHECK(mapping.ReqHost() == nullptr); - CHECK(mapping.ReqStr() == "64:ff9b:1::"); + DNS_Mapping mapping(addr, &he, 123); + CHECK(mapping.Valid()); + CHECK(mapping.ReqAddr() == addr); + CHECK(mapping.ReqHost() == nullptr); + CHECK(mapping.ReqStr() == "64:ff9b:1::"); - auto lva = mapping.Addrs(); - REQUIRE(lva != nullptr); - CHECK(lva->Length() == 1); - auto lvae = lva->Idx(0)->AsAddrVal(); - REQUIRE(lvae != nullptr); - CHECK(lvae->Get().AsString() == "64:ff9b:1::"); + auto lva = mapping.Addrs(); + REQUIRE(lva != nullptr); + CHECK(lva->Length() == 1); + auto lvae = lva->Idx(0)->AsAddrVal(); + REQUIRE(lvae != nullptr); + CHECK(lvae->Get().AsString() == "64:ff9b:1::"); - delete[] he.h_name; - } + delete[] he.h_name; +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DNS_Mapping.h b/src/DNS_Mapping.h index fa761c4e51..7aadfb6355 100644 --- a/src/DNS_Mapping.h +++ b/src/DNS_Mapping.h @@ -8,73 +8,71 @@ #include "zeek/IPAddr.h" #include "zeek/Val.h" -namespace zeek::detail - { +namespace zeek::detail { class DNS_Mapping; using DNS_MappingPtr = std::shared_ptr; -class DNS_Mapping - { +class DNS_Mapping { public: - DNS_Mapping() = delete; - DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type); - DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); - DNS_Mapping(FILE* f); + DNS_Mapping() = delete; + DNS_Mapping(std::string host, struct hostent* h, uint32_t ttl, int type); + DNS_Mapping(const IPAddr& addr, struct hostent* h, uint32_t ttl); + DNS_Mapping(FILE* f); - bool NoMapping() const { return no_mapping; } - bool InitFailed() const { return init_failed; } + bool NoMapping() const { return no_mapping; } + bool InitFailed() const { return init_failed; } - ~DNS_Mapping() = default; + ~DNS_Mapping() = default; - // Returns nil if this was an address request. - // TODO: fix this an uses of this to just return the empty string - const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); } - const IPAddr& ReqAddr() const { return req_addr; } - std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; } - int ReqType() const { return req_type; } + // Returns nil if this was an address request. + // TODO: fix this an uses of this to just return the empty string + const char* ReqHost() const { return req_host.empty() ? nullptr : req_host.c_str(); } + const IPAddr& ReqAddr() const { return req_addr; } + std::string ReqStr() const { return req_host.empty() ? req_addr.AsString() : req_host; } + int ReqType() const { return req_type; } - ListValPtr Addrs(); - TableValPtr AddrsSet(); // addresses returned as a set - StringValPtr Host(); + ListValPtr Addrs(); + TableValPtr AddrsSet(); // addresses returned as a set + StringValPtr Host(); - double CreationTime() const { return creation_time; } - uint32_t TTL() const { return req_ttl; } + double CreationTime() const { return creation_time; } + uint32_t TTL() const { return req_ttl; } - void Save(FILE* f) const; + void Save(FILE* f) const; - bool Failed() const { return failed; } - bool Valid() const { return ! failed; } + bool Failed() const { return failed; } + bool Valid() const { return ! failed; } - bool Expired() const { return util::current_time() > (creation_time + req_ttl); } + bool Expired() const { return util::current_time() > (creation_time + req_ttl); } - void Merge(const DNS_MappingPtr& other); + void Merge(const DNS_MappingPtr& other); - static void InitializeCache(FILE* f); - static bool ValidateCacheVersion(FILE* f); + static void InitializeCache(FILE* f); + static bool ValidateCacheVersion(FILE* f); protected: - friend class DNS_Mgr; + friend class DNS_Mgr; - void Init(struct hostent* h); - void Clear(); + void Init(struct hostent* h); + void Clear(); - std::string req_host; - IPAddr req_addr; - uint32_t req_ttl = 0; - int req_type = 0; + std::string req_host; + IPAddr req_addr; + uint32_t req_ttl = 0; + int req_type = 0; - // This class supports multiple names per address, but we only store one of them. - std::vector names; - StringValPtr host_val; + // This class supports multiple names per address, but we only store one of them. + std::vector names; + StringValPtr host_val; - std::vector addrs; - ListValPtr addrs_val; + std::vector addrs; + ListValPtr addrs_val; - double creation_time = 0.0; - bool no_mapping = false; // when initializing from a file, immediately hit EOF - bool init_failed = false; - bool failed = false; - }; + double creation_time = 0.0; + bool no_mapping = false; // when initializing from a file, immediately hit EOF + bool init_failed = false; + bool failed = false; +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DNS_Mgr.cc b/src/DNS_Mgr.cc index c0e7b331cc..a7747e3c67 100644 --- a/src/DNS_Mgr.cc +++ b/src/DNS_Mgr.cc @@ -57,246 +57,173 @@ constexpr int MAX_PENDING_REQUESTS = 20; constexpr int MAX_UDP_BUFFER_SIZE = 4096; // This unfortunately doesn't exist in c-ares, even though it seems rather useful. -static const char* request_type_string(int request_type) - { - switch ( request_type ) - { - case T_A: - return "T_A"; - case T_NS: - return "T_NS"; - case T_MD: - return "T_MD"; - case T_MF: - return "T_MF"; - case T_CNAME: - return "T_CNAME"; - case T_SOA: - return "T_SOA"; - case T_MB: - return "T_MB"; - case T_MG: - return "T_MG"; - case T_MR: - return "T_MR"; - case T_NULL: - return "T_NULL"; - case T_WKS: - return "T_WKS"; - case T_PTR: - return "T_PTR"; - case T_HINFO: - return "T_HINFO"; - case T_MINFO: - return "T_MINFO"; - case T_MX: - return "T_MX"; - case T_TXT: - return "T_TXT"; - case T_RP: - return "T_RP"; - case T_AFSDB: - return "T_AFSDB"; - case T_X25: - return "T_X25"; - case T_ISDN: - return "T_ISDN"; - case T_RT: - return "T_RT"; - case T_NSAP: - return "T_NSAP"; - case T_NSAP_PTR: - return "T_NSAP_PTR"; - case T_SIG: - return "T_SIG"; - case T_KEY: - return "T_KEY"; - case T_PX: - return "T_PX"; - case T_GPOS: - return "T_GPOS"; - case T_AAAA: - return "T_AAAA"; - case T_LOC: - return "T_LOC"; - case T_NXT: - return "T_NXT"; - case T_EID: - return "T_EID"; - case T_NIMLOC: - return "T_NIMLOC"; - case T_SRV: - return "T_SRV"; - case T_ATMA: - return "T_ATMA"; - case T_NAPTR: - return "T_NAPTR"; - case T_KX: - return "T_KX"; - case T_CERT: - return "T_CERT"; - case T_A6: - return "T_A6"; - case T_DNAME: - return "T_DNAME"; - case T_SINK: - return "T_SINK"; - case T_OPT: - return "T_OPT"; - case T_APL: - return "T_APL"; - case T_DS: - return "T_DS"; - case T_SSHFP: - return "T_SSHFP"; - case T_RRSIG: - return "T_RRSIG"; - case T_NSEC: - return "T_NSEC"; - case T_DNSKEY: - return "T_DNSKEY"; - case T_TKEY: - return "T_TKEY"; - case T_TSIG: - return "T_TSIG"; - case T_IXFR: - return "T_IXFR"; - case T_AXFR: - return "T_AXFR"; - case T_MAILB: - return "T_MAILB"; - case T_MAILA: - return "T_MAILA"; - case T_ANY: - return "T_ANY"; - case T_URI: - return "T_URI"; - case T_CAA: - return "T_CAA"; - case T_MAX: - return "T_MAX"; - default: - return ""; - } - } +static const char* request_type_string(int request_type) { + switch ( request_type ) { + case T_A: return "T_A"; + case T_NS: return "T_NS"; + case T_MD: return "T_MD"; + case T_MF: return "T_MF"; + case T_CNAME: return "T_CNAME"; + case T_SOA: return "T_SOA"; + case T_MB: return "T_MB"; + case T_MG: return "T_MG"; + case T_MR: return "T_MR"; + case T_NULL: return "T_NULL"; + case T_WKS: return "T_WKS"; + case T_PTR: return "T_PTR"; + case T_HINFO: return "T_HINFO"; + case T_MINFO: return "T_MINFO"; + case T_MX: return "T_MX"; + case T_TXT: return "T_TXT"; + case T_RP: return "T_RP"; + case T_AFSDB: return "T_AFSDB"; + case T_X25: return "T_X25"; + case T_ISDN: return "T_ISDN"; + case T_RT: return "T_RT"; + case T_NSAP: return "T_NSAP"; + case T_NSAP_PTR: return "T_NSAP_PTR"; + case T_SIG: return "T_SIG"; + case T_KEY: return "T_KEY"; + case T_PX: return "T_PX"; + case T_GPOS: return "T_GPOS"; + case T_AAAA: return "T_AAAA"; + case T_LOC: return "T_LOC"; + case T_NXT: return "T_NXT"; + case T_EID: return "T_EID"; + case T_NIMLOC: return "T_NIMLOC"; + case T_SRV: return "T_SRV"; + case T_ATMA: return "T_ATMA"; + case T_NAPTR: return "T_NAPTR"; + case T_KX: return "T_KX"; + case T_CERT: return "T_CERT"; + case T_A6: return "T_A6"; + case T_DNAME: return "T_DNAME"; + case T_SINK: return "T_SINK"; + case T_OPT: return "T_OPT"; + case T_APL: return "T_APL"; + case T_DS: return "T_DS"; + case T_SSHFP: return "T_SSHFP"; + case T_RRSIG: return "T_RRSIG"; + case T_NSEC: return "T_NSEC"; + case T_DNSKEY: return "T_DNSKEY"; + case T_TKEY: return "T_TKEY"; + case T_TSIG: return "T_TSIG"; + case T_IXFR: return "T_IXFR"; + case T_AXFR: return "T_AXFR"; + case T_MAILB: return "T_MAILB"; + case T_MAILA: return "T_MAILA"; + case T_ANY: return "T_ANY"; + case T_URI: return "T_URI"; + case T_CAA: return "T_CAA"; + case T_MAX: return "T_MAX"; + default: return ""; + } +} -struct ares_deleter - { - void operator()(char* s) const { ares_free_string(s); } - void operator()(unsigned char* s) const { ares_free_string(s); } - void operator()(ares_addrinfo* s) const { ares_freeaddrinfo(s); } - void operator()(struct hostent* h) const { ares_free_hostent(h); } - void operator()(struct ares_txt_reply* h) const { ares_free_data(h); } - }; +struct ares_deleter { + void operator()(char* s) const { ares_free_string(s); } + void operator()(unsigned char* s) const { ares_free_string(s); } + void operator()(ares_addrinfo* s) const { ares_freeaddrinfo(s); } + void operator()(struct hostent* h) const { ares_free_hostent(h); } + void operator()(struct ares_txt_reply* h) const { ares_free_data(h); } +}; -namespace zeek::detail - { +namespace zeek::detail { static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result); static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, int len); static void sock_cb(void* data, int s, int read, int write); -struct CallbackArgs - { - DNS_Request* req; - DNS_Mgr* mgr; - }; +struct CallbackArgs { + DNS_Request* req; + DNS_Mgr* mgr; +}; -class DNS_Request - { +class DNS_Request { public: - DNS_Request(std::string host, int request_type, bool async = false); - DNS_Request(const IPAddr& addr, bool async = false); - ~DNS_Request() = default; + DNS_Request(std::string host, int request_type, bool async = false); + DNS_Request(const IPAddr& addr, bool async = false); + ~DNS_Request() = default; - std::string Host() const { return host; } - const IPAddr& Addr() const { return addr; } - int RequestType() const { return request_type; } - bool IsTxt() const { return request_type == 16; } + std::string Host() const { return host; } + const IPAddr& Addr() const { return addr; } + int RequestType() const { return request_type; } + bool IsTxt() const { return request_type == 16; } - void MakeRequest(ares_channel channel, DNS_Mgr* mgr); - void ProcessAsyncResult(bool timed_out, DNS_Mgr* mgr); + void MakeRequest(ares_channel channel, DNS_Mgr* mgr); + void ProcessAsyncResult(bool timed_out, DNS_Mgr* mgr); private: - std::string host; - IPAddr addr; - int request_type = 0; // Query type - bool async = false; - std::unique_ptr query; - static uint16_t request_id; - }; + std::string host; + IPAddr addr; + int request_type = 0; // Query type + bool async = false; + std::unique_ptr query; + static uint16_t request_id; +}; uint16_t DNS_Request::request_id = 0; DNS_Request::DNS_Request(std::string host, int request_type, bool async) - : host(std::move(host)), request_type(request_type), async(async) - { - // We combine the T_A and T_AAAA requests together in one request, so set the type - // to T_A to make things easier in other parts of the code (mostly around lookups). - if ( request_type == T_AAAA ) - request_type = T_A; - } + : host(std::move(host)), request_type(request_type), async(async) { + // We combine the T_A and T_AAAA requests together in one request, so set the type + // to T_A to make things easier in other parts of the code (mostly around lookups). + if ( request_type == T_AAAA ) + request_type = T_A; +} -DNS_Request::DNS_Request(const IPAddr& addr, bool async) : addr(addr), async(async) - { - request_type = T_PTR; - } +DNS_Request::DNS_Request(const IPAddr& addr, bool async) : addr(addr), async(async) { request_type = T_PTR; } -void DNS_Request::MakeRequest(ares_channel channel, DNS_Mgr* mgr) - { - // This needs to get deleted at the end of the callback method. - auto req_data = std::make_unique(); - req_data->req = this; - req_data->mgr = mgr; +void DNS_Request::MakeRequest(ares_channel channel, DNS_Mgr* mgr) { + // This needs to get deleted at the end of the callback method. + auto req_data = std::make_unique(); + req_data->req = this; + req_data->mgr = mgr; - // It's completely fine if this rolls over. It's just to keep the query ID different - // from one query to the next, and it's unlikely we'd do 2^16 queries so fast that - // all of them would be in flight at the same time. - DNS_Request::request_id++; + // It's completely fine if this rolls over. It's just to keep the query ID different + // from one query to the next, and it's unlikely we'd do 2^16 queries so fast that + // all of them would be in flight at the same time. + DNS_Request::request_id++; - if ( request_type == T_A ) - { - // For A/AAAA requests, we use a different method than the other requests. Since - // we're using the AF_UNSPEC family, we get both the ipv4 and ipv6 responses - // back in the same request if use ares_getaddrinfo() so we can store them both - // in the same mapping. - ares_addrinfo_hints hints = {ARES_AI_CANONNAME, AF_UNSPEC, 0, 0}; - ares_getaddrinfo(channel, host.c_str(), NULL, &hints, addrinfo_cb, req_data.release()); - } - else - { - std::string query_host; - if ( request_type == T_PTR ) - query_host = addr.PtrName(); - else - query_host = host; + if ( request_type == T_A ) { + // For A/AAAA requests, we use a different method than the other requests. Since + // we're using the AF_UNSPEC family, we get both the ipv4 and ipv6 responses + // back in the same request if use ares_getaddrinfo() so we can store them both + // in the same mapping. + ares_addrinfo_hints hints = {ARES_AI_CANONNAME, AF_UNSPEC, 0, 0}; + ares_getaddrinfo(channel, host.c_str(), NULL, &hints, addrinfo_cb, req_data.release()); + } + else { + std::string query_host; + if ( request_type == T_PTR ) + query_host = addr.PtrName(); + else + query_host = host; - std::unique_ptr query_str; - int len = 0; - int status = ares_create_query( - query_host.c_str(), C_IN, request_type, DNS_Request::request_id, 1, - out_ptr(query_str), &len, MAX_UDP_BUFFER_SIZE); + std::unique_ptr query_str; + int len = 0; + int status = ares_create_query(query_host.c_str(), C_IN, request_type, DNS_Request::request_id, 1, + out_ptr(query_str), &len, MAX_UDP_BUFFER_SIZE); - if ( status != ARES_SUCCESS || query_str == nullptr ) - return; + if ( status != ARES_SUCCESS || query_str == nullptr ) + return; - // Store this so it can be destroyed when the request is destroyed. - this->query = std::move(query_str); - ares_send(channel, this->query.get(), len, query_cb, req_data.release()); - } - } + // Store this so it can be destroyed when the request is destroyed. + this->query = std::move(query_str); + ares_send(channel, this->query.get(), len, query_cb, req_data.release()); + } +} -void DNS_Request::ProcessAsyncResult(bool timed_out, DNS_Mgr* mgr) - { - if ( ! async ) - return; +void DNS_Request::ProcessAsyncResult(bool timed_out, DNS_Mgr* mgr) { + if ( ! async ) + return; - if ( request_type == T_A ) - mgr->CheckAsyncHostRequest(host, timed_out); - else if ( request_type == T_PTR ) - mgr->CheckAsyncAddrRequest(addr, timed_out); - else - mgr->CheckAsyncOtherRequest(host, timed_out, request_type); - } + if ( request_type == T_A ) + mgr->CheckAsyncHostRequest(host, timed_out); + else if ( request_type == T_PTR ) + mgr->CheckAsyncAddrRequest(addr, timed_out); + else + mgr->CheckAsyncOtherRequest(host, timed_out, request_type); +} /** * Retrieves the TTL value from the first RR in the response. @@ -314,1281 +241,1122 @@ void DNS_Request::ProcessAsyncResult(bool timed_out, DNS_Mgr* mgr) * @return A status code from c-ares. This will be ARES_SUCCESS on success, or some other * code on failure. */ -static int get_ttl(unsigned char* abuf, int alen, int* ttl) - { - int status; - long len; - std::unique_ptr hostname; +static int get_ttl(unsigned char* abuf, int alen, int* ttl) { + int status; + long len; + std::unique_ptr hostname; - *ttl = DNS_TIMEOUT; + *ttl = DNS_TIMEOUT; - unsigned char* aptr = abuf + HFIXEDSZ; - status = ares_expand_name(aptr, abuf, alen, out_ptr(hostname), &len); - if ( status != ARES_SUCCESS ) - return status; + unsigned char* aptr = abuf + HFIXEDSZ; + status = ares_expand_name(aptr, abuf, alen, out_ptr(hostname), &len); + if ( status != ARES_SUCCESS ) + return status; - if ( aptr + len + QFIXEDSZ > abuf + alen ) - return ARES_EBADRESP; + if ( aptr + len + QFIXEDSZ > abuf + alen ) + return ARES_EBADRESP; - aptr += len + QFIXEDSZ; - hostname.reset(); + aptr += len + QFIXEDSZ; + hostname.reset(); - status = ares_expand_name(aptr, abuf, alen, out_ptr(hostname), &len); - if ( status != ARES_SUCCESS ) - return status; + status = ares_expand_name(aptr, abuf, alen, out_ptr(hostname), &len); + if ( status != ARES_SUCCESS ) + return status; - if ( aptr + RRFIXEDSZ > abuf + alen ) - return ARES_EBADRESP; + if ( aptr + RRFIXEDSZ > abuf + alen ) + return ARES_EBADRESP; - aptr += len; - *ttl = DNS_RR_TTL(aptr); + aptr += len; + *ttl = DNS_RR_TTL(aptr); - return status; - } + return status; +} /** * Called in response to ares_getaddrinfo requests. Builds a hostent structure from * the result data and sends it to the DNS manager via AddResult(). */ -static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result) - { - auto arg_data = reinterpret_cast(arg); - const auto [req, mgr] = *arg_data; - std::unique_ptr res_ptr(result); +static void addrinfo_cb(void* arg, int status, int timeouts, struct ares_addrinfo* result) { + auto arg_data = reinterpret_cast(arg); + const auto [req, mgr] = *arg_data; + std::unique_ptr res_ptr(result); - if ( status != ARES_SUCCESS ) - { - // These two statuses should only ever be sent if we're shutting down everything - // and all of the existing queries are being cancelled. There's no reason to - // store a status that's just going to get deleted, nor is there a reason to log - // anything. - if ( status != ARES_ECANCELLED && status != ARES_EDESTRUCTION ) - { - // Insert something into the cache so that the request loop will end correctly. - // We use the DNS_TIMEOUT value as the TTL here since it's small enough that the - // failed response will expire soon, and because we don't have the TTL from the - // response data. - mgr->AddResult(req, nullptr, DNS_TIMEOUT); - } - } - else - { - std::vector addrs; - std::vector addrs6; - for ( ares_addrinfo_node* entry = result->nodes; entry != NULL; entry = entry->ai_next ) - { - if ( entry->ai_family == AF_INET ) - { - struct sockaddr_in* addr = reinterpret_cast(entry->ai_addr); - addrs.push_back(&addr->sin_addr); - } - else if ( entry->ai_family == AF_INET6 ) - { - struct sockaddr_in6* addr = (struct sockaddr_in6*)(entry->ai_addr); - addrs6.push_back(&addr->sin6_addr); - } - } + if ( status != ARES_SUCCESS ) { + // These two statuses should only ever be sent if we're shutting down everything + // and all of the existing queries are being cancelled. There's no reason to + // store a status that's just going to get deleted, nor is there a reason to log + // anything. + if ( status != ARES_ECANCELLED && status != ARES_EDESTRUCTION ) { + // Insert something into the cache so that the request loop will end correctly. + // We use the DNS_TIMEOUT value as the TTL here since it's small enough that the + // failed response will expire soon, and because we don't have the TTL from the + // response data. + mgr->AddResult(req, nullptr, DNS_TIMEOUT); + } + } + else { + std::vector addrs; + std::vector addrs6; + for ( ares_addrinfo_node* entry = result->nodes; entry != NULL; entry = entry->ai_next ) { + if ( entry->ai_family == AF_INET ) { + struct sockaddr_in* addr = reinterpret_cast(entry->ai_addr); + addrs.push_back(&addr->sin_addr); + } + else if ( entry->ai_family == AF_INET6 ) { + struct sockaddr_in6* addr = (struct sockaddr_in6*)(entry->ai_addr); + addrs6.push_back(&addr->sin6_addr); + } + } - if ( ! addrs.empty() ) - { - // Push a null on the end so the addr list has a final point during later parsing. - addrs.push_back(NULL); + if ( ! addrs.empty() ) { + // Push a null on the end so the addr list has a final point during later parsing. + addrs.push_back(NULL); - struct hostent he - { - }; - he.h_name = util::copy_string(result->name); - he.h_addrtype = AF_INET; - he.h_length = sizeof(in_addr); - he.h_addr_list = reinterpret_cast(addrs.data()); + struct hostent he {}; + he.h_name = util::copy_string(result->name); + he.h_addrtype = AF_INET; + he.h_length = sizeof(in_addr); + he.h_addr_list = reinterpret_cast(addrs.data()); - mgr->AddResult(req, &he, result->nodes[0].ai_ttl); + mgr->AddResult(req, &he, result->nodes[0].ai_ttl); - delete[] he.h_name; - } + delete[] he.h_name; + } - if ( ! addrs6.empty() ) - { - // Push a null on the end so the addr list has a final point during later parsing. - addrs6.push_back(NULL); + if ( ! addrs6.empty() ) { + // Push a null on the end so the addr list has a final point during later parsing. + addrs6.push_back(NULL); - struct hostent he - { - }; - he.h_name = util::copy_string(result->name); - he.h_addrtype = AF_INET6; - he.h_length = sizeof(in6_addr); - he.h_addr_list = reinterpret_cast(addrs6.data()); + struct hostent he {}; + he.h_name = util::copy_string(result->name); + he.h_addrtype = AF_INET6; + he.h_length = sizeof(in6_addr); + he.h_addr_list = reinterpret_cast(addrs6.data()); - mgr->AddResult(req, &he, result->nodes[0].ai_ttl, true); + mgr->AddResult(req, &he, result->nodes[0].ai_ttl, true); - delete[] he.h_name; - } - } + delete[] he.h_name; + } + } - req->ProcessAsyncResult(timeouts > 0, mgr); + req->ProcessAsyncResult(timeouts > 0, mgr); - // TODO: might need to turn these into unique_ptr as well? - delete req; - delete arg_data; - } + // TODO: might need to turn these into unique_ptr as well? + delete req; + delete arg_data; +} -static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, int len) - { - auto arg_data = reinterpret_cast(arg); - const auto [req, mgr] = *arg_data; +static void query_cb(void* arg, int status, int timeouts, unsigned char* buf, int len) { + auto arg_data = reinterpret_cast(arg); + const auto [req, mgr] = *arg_data; - if ( status != ARES_SUCCESS ) - { - // These two statuses should only ever be sent if we're shutting down everything - // and all of the existing queries are being cancelled. There's no reason to - // store a status that's just going to get deleted, nor is there a reason to log - // anything. - if ( status != ARES_ECANCELLED && status != ARES_EDESTRUCTION ) - { - // Insert something into the cache so that the request loop will end correctly. - // We use the DNS_TIMEOUT value as the TTL here since it's small enough that the - // failed response will expire soon, and because we don't have the TTL from the - // response data. - mgr->AddResult(req, nullptr, DNS_TIMEOUT); - } - } - else - { - // We don't really care that we couldn't properly parse the TTL here, since the - // later parsing will fail with better error messages. In that case, it's ok - // that we throw away the status value. - int ttl; - get_ttl(buf, len, &ttl); + if ( status != ARES_SUCCESS ) { + // These two statuses should only ever be sent if we're shutting down everything + // and all of the existing queries are being cancelled. There's no reason to + // store a status that's just going to get deleted, nor is there a reason to log + // anything. + if ( status != ARES_ECANCELLED && status != ARES_EDESTRUCTION ) { + // Insert something into the cache so that the request loop will end correctly. + // We use the DNS_TIMEOUT value as the TTL here since it's small enough that the + // failed response will expire soon, and because we don't have the TTL from the + // response data. + mgr->AddResult(req, nullptr, DNS_TIMEOUT); + } + } + else { + // We don't really care that we couldn't properly parse the TTL here, since the + // later parsing will fail with better error messages. In that case, it's ok + // that we throw away the status value. + int ttl; + get_ttl(buf, len, &ttl); - switch ( req->RequestType() ) - { - case T_PTR: - { - std::unique_ptr he; - if ( req->Addr().GetFamily() == IPv4 ) - { - struct in_addr addr; - req->Addr().CopyIPv4(&addr); - status = ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET, - out_ptr(he)); - } - else - { - struct in6_addr addr; - req->Addr().CopyIPv6(&addr); - status = ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET6, - out_ptr(he)); - } + switch ( req->RequestType() ) { + case T_PTR: { + std::unique_ptr he; + if ( req->Addr().GetFamily() == IPv4 ) { + struct in_addr addr; + req->Addr().CopyIPv4(&addr); + status = ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET, out_ptr(he)); + } + else { + struct in6_addr addr; + req->Addr().CopyIPv6(&addr); + status = + ares_parse_ptr_reply(buf, len, &addr, sizeof(addr), AF_INET6, out_ptr(he)); + } - if ( status == ARES_SUCCESS ) - mgr->AddResult(req, he.get(), ttl); - else - { - // See above for why DNS_TIMEOUT here. - mgr->AddResult(req, nullptr, DNS_TIMEOUT); - } - break; - } - case T_TXT: - { - std::unique_ptr reply; - int r = ares_parse_txt_reply(buf, len, out_ptr(reply)); - if ( r == ARES_SUCCESS ) - { - // Use a hostent to send the data into AddResult(). We only care about - // setting the host field, but everything else should be zero just for - // safety. + if ( status == ARES_SUCCESS ) + mgr->AddResult(req, he.get(), ttl); + else { + // See above for why DNS_TIMEOUT here. + mgr->AddResult(req, nullptr, DNS_TIMEOUT); + } + break; + } + case T_TXT: { + std::unique_ptr reply; + int r = ares_parse_txt_reply(buf, len, out_ptr(reply)); + if ( r == ARES_SUCCESS ) { + // Use a hostent to send the data into AddResult(). We only care about + // setting the host field, but everything else should be zero just for + // safety. - // We don't currently handle more than the first response, and throw the - // rest away. There really isn't a good reason for this, we just haven't - // ever done so. It would likely require some changes to the output from - // Lookup(), since right now it only returns one value. - struct hostent he - { - }; - he.h_name = util::copy_string(reinterpret_cast(reply->txt)); - mgr->AddResult(req, &he, ttl); + // We don't currently handle more than the first response, and throw the + // rest away. There really isn't a good reason for this, we just haven't + // ever done so. It would likely require some changes to the output from + // Lookup(), since right now it only returns one value. + struct hostent he {}; + he.h_name = util::copy_string(reinterpret_cast(reply->txt)); + mgr->AddResult(req, &he, ttl); - delete[] he.h_name; - } - else - { - // See above for why DNS_TIMEOUT here. - mgr->AddResult(req, nullptr, DNS_TIMEOUT); - } + delete[] he.h_name; + } + else { + // See above for why DNS_TIMEOUT here. + mgr->AddResult(req, nullptr, DNS_TIMEOUT); + } - break; - } + break; + } - default: - reporter->Error("Requests of type %d (%s) are unsupported", req->RequestType(), - request_type_string(req->RequestType())); - break; - } - } + default: + reporter->Error("Requests of type %d (%s) are unsupported", req->RequestType(), + request_type_string(req->RequestType())); + break; + } + } - req->ProcessAsyncResult(timeouts > 0, mgr); - delete arg_data; - delete req; - } + req->ProcessAsyncResult(timeouts > 0, mgr); + delete arg_data; + delete req; +} /** * Called when the c-ares socket changes state, which indicates that it's connected to * some source of data (either a host file or a DNS server). This indicates that we're * able to do lookups against c-ares now and should activate the IOSource. */ -static void sock_cb(void* data, ares_socket_t s, int read, int write) - { - auto mgr = reinterpret_cast(data); - mgr->RegisterSocket((int)s, read == 1, write == 1); - } - -DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) : IOSource(true), mode(arg_mode) - { - ares_library_init(ARES_LIB_INIT_ALL); - } - -DNS_Mgr::~DNS_Mgr() - { - Flush(); - - ares_cancel(channel); - ares_destroy(channel); - ares_library_cleanup(); - } - -void DNS_Mgr::Done() - { - shutting_down = true; - Flush(); - } - -void DNS_Mgr::RegisterSocket(int fd, bool read, bool write) - { - if ( read && socket_fds.count(fd) == 0 ) - { - socket_fds.insert(fd); - iosource_mgr->RegisterFd(fd, this, IOSource::READ); - } - else if ( ! read && socket_fds.count(fd) != 0 ) - { - socket_fds.erase(fd); - iosource_mgr->UnregisterFd(fd, this, IOSource::READ); - } - - if ( write && write_socket_fds.count(fd) == 0 ) - { - write_socket_fds.insert(fd); - iosource_mgr->RegisterFd(fd, this, IOSource::WRITE); - } - else if ( ! write && write_socket_fds.count(fd) != 0 ) - { - write_socket_fds.erase(fd); - iosource_mgr->UnregisterFd(fd, this, IOSource::WRITE); - } - } - -void DNS_Mgr::InitSource() - { - if ( did_init ) - return; - - ares_options options; - int optmask = 0; - - // Enable an EDNS option to be sent with the requests. This allows us to set - // a bigger UDP buffer size in the request, which prevents fallback to TCP - // at least up to that size. - options.flags = ARES_FLAG_EDNS; - optmask |= ARES_OPT_FLAGS; - - options.ednspsz = MAX_UDP_BUFFER_SIZE; - optmask |= ARES_OPT_EDNSPSZ; - - options.socket_receive_buffer_size = MAX_UDP_BUFFER_SIZE; - optmask |= ARES_OPT_SOCK_RCVBUF; - - // This option is in milliseconds. - options.timeout = DNS_TIMEOUT * 1000; - optmask |= ARES_OPT_TIMEOUTMS; - - // This causes c-ares to only attempt each server twice before - // giving up. - options.tries = 2; - optmask |= ARES_OPT_TRIES; - - // See the comment on sock_cb for how this gets used. - options.sock_state_cb = sock_cb; - options.sock_state_cb_data = this; - optmask |= ARES_OPT_SOCK_STATE_CB; - - int status = ares_init_options(&channel, &options, optmask); - if ( status != ARES_SUCCESS ) - reporter->FatalError("Failed to initialize c-ares for DNS resolution: %s", - ares_strerror(status)); - - // Note that Init() may be called by way of LookupHost() during the act of - // parsing a hostname literal (e.g. google.com), so we can't use a - // script-layer option to configure the DNS resolver as it may not be - // configured to the user's desired address at the time when we need to to - // the lookup. - auto dns_resolver = getenv("ZEEK_DNS_RESOLVER"); - if ( dns_resolver ) - { - ares_addr_node servers; - servers.next = NULL; - - auto dns_resolver_addr = IPAddr(dns_resolver); - - if ( dns_resolver_addr.GetFamily() == IPv4 ) - { - servers.family = AF_INET; - dns_resolver_addr.CopyIPv4(&(servers.addr.addr4)); - } - else - { - struct sockaddr_in6 sa = {0}; - sa.sin6_family = AF_INET6; - dns_resolver_addr.CopyIPv6(&sa.sin6_addr); - - servers.family = AF_INET6; - memcpy(&(servers.addr.addr6), &sa.sin6_addr, sizeof(ares_in6_addr)); - } - - ares_set_servers(channel, &servers); - } - - did_init = true; - } - -void DNS_Mgr::InitPostScript() - { - if ( ! doctest::is_running_in_test ) - { - dm_rec = id::find_type("dns_mapping"); - - // Registering will call InitSource(), which sets up all of the DNS library stuff - iosource_mgr->Register(this, true); - } - else - { - // This would normally be called when registering the iosource above. - InitSource(); - } - - // Load the DNS cache from disk, if it exists. - std::string cache_dir = dir.empty() ? "." : dir; - cache_name = util::fmt("%s/%s", cache_dir.c_str(), ".zeek-dns-cache"); - LoadCache(cache_name); - } - -static TableValPtr fake_name_lookup_result(const std::string& name) - { - hash128_t hash; - KeyedHash::StaticHash128(name.c_str(), name.size(), &hash); - auto hv = make_intrusive(TYPE_ADDR); - hv->Append(make_intrusive(reinterpret_cast(&hash))); - return hv->ToSetVal(); - } - -static std::string fake_lookup_result(const std::string& name, int request_type) - { - return util::fmt("fake_lookup_result_%s_%s", request_type_string(request_type), name.c_str()); - } - -static std::string fake_addr_lookup_result(const IPAddr& addr) - { - return util::fmt("fake_addr_lookup_result_%s", addr.AsString().c_str()); - } - -static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, TableValPtr result) - { - callback->Resolved(std::move(result)); - delete callback; - } - -static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, const std::string& result) - { - callback->Resolved(result); - delete callback; - } - -ValPtr DNS_Mgr::Lookup(const std::string& name, int request_type) - { - if ( shutting_down ) - return nullptr; - - if ( request_type == T_A || request_type == T_AAAA ) - return LookupHost(name); - - if ( mode == DNS_FAKE ) - return make_intrusive(fake_lookup_result(name, request_type)); - - InitSource(); - - if ( mode != DNS_PRIME ) - { - if ( auto val = LookupOtherInCache(name, request_type, false) ) - return val; - } - - switch ( mode ) - { - case DNS_PRIME: - { - auto req = new DNS_Request(name, request_type); - req->MakeRequest(channel, this); - return empty_addr_set(); - } - - case DNS_FORCE: - reporter->FatalError("can't find DNS entry for %s (req type %d / %s) in cache", - name.c_str(), request_type, request_type_string(request_type)); - return nullptr; - - case DNS_DEFAULT: - { - auto req = new DNS_Request(name, request_type); - req->MakeRequest(channel, this); - Resolve(); - - // Call LookupHost() a second time to get the newly stored value out of the cache. - return Lookup(name, request_type); - } - - default: - reporter->InternalError("bad mode %d in DNS_Mgr::Lookup", mode); - return nullptr; - } - - return nullptr; - } - -TableValPtr DNS_Mgr::LookupHost(const std::string& name) - { - if ( shutting_down ) - return nullptr; - - if ( mode == DNS_FAKE ) - return fake_name_lookup_result(name); - - InitSource(); - - // Check the cache before attempting to look up the name remotely. - if ( mode != DNS_PRIME ) - { - if ( auto val = LookupNameInCache(name, false, true) ) - return val; - } - - // Not found, or priming. - switch ( mode ) - { - case DNS_PRIME: - { - // We pass T_A here, but DNSRequest::MakeRequest() will special-case that in - // a request that gets both T_A and T_AAAA results at one time. - auto req = new DNS_Request(name, T_A); - req->MakeRequest(channel, this); - return empty_addr_set(); - } - - case DNS_FORCE: - reporter->FatalError("can't find DNS entry for %s in cache", name.c_str()); - return nullptr; - - case DNS_DEFAULT: - { - // We pass T_A here, but DNSRequest::MakeRequest() will special-case that in - // a request that gets both T_A and T_AAAA results at one time. - auto req = new DNS_Request(name, T_A); - req->MakeRequest(channel, this); - Resolve(); - - // Call LookupHost() a second time to get the newly stored value out of the cache. - return LookupHost(name); - } - - default: - reporter->InternalError("bad mode in DNS_Mgr::LookupHost"); - return nullptr; - } - } - -StringValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) - { - if ( shutting_down ) - return nullptr; - - if ( mode == DNS_FAKE ) - return make_intrusive(fake_addr_lookup_result(addr)); - - InitSource(); - - // Check the cache before attempting to look up the name remotely. - if ( mode != DNS_PRIME ) - { - if ( auto val = LookupAddrInCache(addr, false, true) ) - return val; - } - - // Not found, or priming. - switch ( mode ) - { - case DNS_PRIME: - { - auto req = new DNS_Request(addr); - req->MakeRequest(channel, this); - return make_intrusive(""); - } - - case DNS_FORCE: - reporter->FatalError("can't find DNS entry for %s in cache", addr.AsString().c_str()); - return nullptr; - - case DNS_DEFAULT: - { - auto req = new DNS_Request(addr); - req->MakeRequest(channel, this); - Resolve(); - - // Call LookupAddr() a second time to get the newly stored value out of the cache. - return LookupAddr(addr); - } - - default: - reporter->InternalError("bad mode in DNS_Mgr::LookupAddr"); - return nullptr; - } - } - -void DNS_Mgr::LookupHost(const std::string& name, LookupCallback* callback) - { - if ( shutting_down ) - return; - - if ( mode == DNS_FAKE ) - { - resolve_lookup_cb(callback, fake_name_lookup_result(name)); - return; - } - - // Do we already know the answer? - if ( auto addrs = LookupNameInCache(name, true, false) ) - { - resolve_lookup_cb(callback, std::move(addrs)); - return; - } - - AsyncRequest* req = nullptr; - - // If we already have a request waiting for this host, we don't need to make - // another one. We can just add the callback to it and it'll get handled - // when the first request comes back. - auto key = std::make_pair(T_A, name); - auto i = asyncs.find(key); - if ( i != asyncs.end() ) - req = i->second; - else - { - // A new one. - req = new AsyncRequest{name, T_A}; - asyncs_queued.push_back(req); - asyncs.emplace_hint(i, std::move(key), req); - } - - req->callbacks.push_back(callback); - - // There may be requests in the queue that haven't been processed yet - // so go ahead and reissue them, even if this method didn't change - // anything. - IssueAsyncRequests(); - } - -void DNS_Mgr::LookupAddr(const IPAddr& addr, LookupCallback* callback) - { - if ( shutting_down ) - return; - - if ( mode == DNS_FAKE ) - { - resolve_lookup_cb(callback, fake_addr_lookup_result(addr)); - return; - } - - // Do we already know the answer? - if ( auto name = LookupAddrInCache(addr, true, false) ) - { - resolve_lookup_cb(callback, name->CheckString()); - return; - } - - AsyncRequest* req = nullptr; - - // If we already have a request waiting for this host, we don't need to make - // another one. We can just add the callback to it and it'll get handled - // when the first request comes back. - auto i = asyncs.find(addr); - if ( i != asyncs.end() ) - req = i->second; - else - { - // A new one. - req = new AsyncRequest{addr}; - asyncs_queued.push_back(req); - asyncs.emplace_hint(i, addr, req); - } - - req->callbacks.push_back(callback); - - // There may be requests in the queue that haven't been processed yet - // so go ahead and reissue them, even if this method didn't change - // anything. - IssueAsyncRequests(); - } - -void DNS_Mgr::Lookup(const std::string& name, int request_type, LookupCallback* callback) - { - if ( shutting_down ) - return; - - if ( mode == DNS_FAKE ) - { - resolve_lookup_cb(callback, fake_lookup_result(name, request_type)); - return; - } - - // Do we already know the answer? - if ( auto txt = LookupOtherInCache(name, request_type, true) ) - { - resolve_lookup_cb(callback, txt->CheckString()); - return; - } - - AsyncRequest* req = nullptr; - - // If we already have a request waiting for this host, we don't need to make - // another one. We can just add the callback to it and it'll get handled - // when the first request comes back. - auto key = std::make_pair(request_type, name); - auto i = asyncs.find(key); - if ( i != asyncs.end() ) - req = i->second; - else - { - // A new one. - req = new AsyncRequest{name, request_type}; - asyncs_queued.push_back(req); - asyncs.emplace_hint(i, std::move(key), req); - } - - req->callbacks.push_back(callback); - - IssueAsyncRequests(); - } - -void DNS_Mgr::Resolve() - { - int nfds = 0; - struct timeval *tvp, tv; - struct pollfd pollfds[ARES_GETSOCK_MAXNUM]; - ares_socket_t socks[ARES_GETSOCK_MAXNUM]; - - tv.tv_sec = DNS_TIMEOUT; - tv.tv_usec = 0; - - for ( int i = 0; i < MAX_PENDING_REQUESTS; i++ ) - { - int nfds = 0; - int bitmap = ares_getsock(channel, socks, ARES_GETSOCK_MAXNUM); - - for ( int i = 0; i < ARES_GETSOCK_MAXNUM; i++ ) - { - bool rd = ARES_GETSOCK_READABLE(bitmap, i); - bool wr = ARES_GETSOCK_WRITABLE(bitmap, i); - if ( rd || wr ) - { - pollfds[nfds].fd = socks[i]; - pollfds[nfds].events = rd ? POLLIN : 0; - pollfds[nfds].events |= wr ? POLLOUT : 0; - ++nfds; - } - } - - // Do we have any sockets that are read or writable? - if ( nfds == 0 ) - break; - - // poll() timeout is in milliseconds. - tvp = ares_timeout(channel, &tv, &tv); - int timeout_ms = tvp->tv_sec * 1000 + tvp->tv_usec / 1000; - - int res = poll(pollfds, nfds, timeout_ms); - - if ( res > 0 ) - { - for ( int i = 0; i < nfds; i++ ) - { - int rdfd = pollfds[i].revents & POLLIN ? pollfds[i].fd : ARES_SOCKET_BAD; - int wrfd = pollfds[i].revents & POLLOUT ? pollfds[i].fd : ARES_SOCKET_BAD; - - if ( rdfd != ARES_SOCKET_BAD || wrfd != ARES_SOCKET_BAD ) - ares_process_fd(channel, rdfd, wrfd); - } - } - else if ( res == 0 ) - // Do timeout processing when poll() timed out. - ares_process_fd(channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); - } - } - -void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& dm) - { - if ( e ) - event_mgr.Enqueue(e, BuildMappingVal(dm)); - } - -void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2) - { - if ( e ) - event_mgr.Enqueue(e, BuildMappingVal(dm), l1->ToSetVal(), l2->ToSetVal()); - } - -void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm) - { - if ( e ) - event_mgr.Enqueue(e, BuildMappingVal(old_dm), BuildMappingVal(new_dm)); - } - -ValPtr DNS_Mgr::BuildMappingVal(const DNS_MappingPtr& dm) - { - if ( ! dm_rec ) - return nullptr; - - auto r = make_intrusive(dm_rec); - - r->AssignTime(0, dm->CreationTime()); - r->Assign(1, dm->ReqHost() ? dm->ReqHost() : ""); - r->Assign(2, make_intrusive(dm->ReqAddr())); - r->Assign(3, dm->Valid()); - - auto h = dm->Host(); - r->Assign(4, h ? std::move(h) : make_intrusive("")); - r->Assign(5, dm->AddrsSet()); - - return r; - } - -void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool merge) - { - // TODO: the existing code doesn't handle hostname aliases at all. Should we? - - DNS_MappingPtr new_mapping = nullptr; - DNS_MappingPtr prev_mapping = nullptr; - bool keep_prev = true; - - MappingMap::iterator it; - if ( dr->RequestType() == T_PTR ) - { - new_mapping = std::make_shared(dr->Addr(), h, ttl); - it = all_mappings.find(dr->Addr()); - if ( it == all_mappings.end() ) - { - auto result = all_mappings.emplace(dr->Addr(), new_mapping); - it = result.first; - } - else - prev_mapping = it->second; - } - else - { - new_mapping = std::make_shared(dr->Host(), h, ttl, dr->RequestType()); - auto key = std::make_pair(dr->RequestType(), dr->Host()); - - it = all_mappings.find(key); - if ( it == all_mappings.end() ) - { - auto result = all_mappings.emplace(std::move(key), new_mapping); - it = result.first; - } - else - prev_mapping = it->second; - } - - if ( prev_mapping && prev_mapping->Valid() ) - { - if ( new_mapping->Valid() ) - { - if ( merge ) - new_mapping->Merge(prev_mapping); - - it->second = new_mapping; - keep_prev = false; - } - } - else - { - it->second = new_mapping; - keep_prev = false; - } - - if ( prev_mapping && ! dr->IsTxt() ) - CompareMappings(prev_mapping, new_mapping); - - if ( keep_prev ) - new_mapping.reset(); - else - prev_mapping.reset(); - } - -void DNS_Mgr::CompareMappings(const DNS_MappingPtr& prev_mapping, const DNS_MappingPtr& new_mapping) - { - if ( prev_mapping->Failed() ) - { - if ( new_mapping->Failed() ) - // Nothing changed. - return; - - Event(dns_mapping_valid, new_mapping); - return; - } - - else if ( new_mapping->Failed() ) - { - Event(dns_mapping_unverified, prev_mapping); - return; - } - - auto prev_s = prev_mapping->Host(); - auto new_s = new_mapping->Host(); - - if ( prev_s || new_s ) - { - if ( ! prev_s ) - Event(dns_mapping_new_name, new_mapping); - else if ( ! new_s ) - Event(dns_mapping_lost_name, prev_mapping); - else if ( ! Bstr_eq(new_s->AsString(), prev_s->AsString()) ) - Event(dns_mapping_name_changed, prev_mapping, new_mapping); - } - - auto prev_a = prev_mapping->Addrs(); - auto new_a = new_mapping->Addrs(); - - if ( ! prev_a || ! new_a ) - { - reporter->InternalWarning("confused in DNS_Mgr::CompareMappings"); - return; - } - - auto prev_delta = AddrListDelta(prev_a, new_a); - auto new_delta = AddrListDelta(new_a, prev_a); - - if ( prev_delta->Length() > 0 || new_delta->Length() > 0 ) - Event(dns_mapping_altered, new_mapping, std::move(prev_delta), std::move(new_delta)); - } - -ListValPtr DNS_Mgr::AddrListDelta(ListValPtr al1, ListValPtr al2) - { - auto delta = make_intrusive(TYPE_ADDR); - - for ( int i = 0; i < al1->Length(); ++i ) - { - const IPAddr& al1_i = al1->Idx(i)->AsAddr(); - - int j; - for ( j = 0; j < al2->Length(); ++j ) - { - const IPAddr& al2_j = al2->Idx(j)->AsAddr(); - if ( al1_i == al2_j ) - break; - } - - if ( j >= al2->Length() ) - // Didn't find it. - delta->Append(al1->Idx(i)); - } - - return delta; - } - -void DNS_Mgr::LoadCache(const std::string& path) - { - FILE* f = fopen(path.c_str(), "r"); - - if ( ! f ) - return; - - if ( ! DNS_Mapping::ValidateCacheVersion(f) ) - { - fclose(f); - return; - } - - // Loop until we find a mapping that doesn't initialize correctly. - auto m = std::make_shared(f); - for ( ; ! m->NoMapping() && ! m->InitFailed(); m = std::make_shared(f) ) - { - if ( m->ReqHost() ) - all_mappings.insert_or_assign(std::make_pair(m->ReqType(), m->ReqHost()), m); - else - all_mappings.insert_or_assign(m->ReqAddr(), m); - } - - if ( ! m->NoMapping() ) - reporter->FatalError("DNS cache corrupted"); - - fclose(f); - } - -bool DNS_Mgr::Save() - { - if ( cache_name.empty() ) - return false; - - FILE* f = fopen(cache_name.c_str(), "w"); - - if ( ! f ) - return false; - - DNS_Mapping::InitializeCache(f); - Save(f, all_mappings); - - fclose(f); - - return true; - } - -void DNS_Mgr::Save(FILE* f, const MappingMap& m) - { - for ( const auto& [key, mapping] : m ) - { - if ( mapping ) - mapping->Save(f); - } - } - -TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_expired, - bool check_failed) - { - auto it = all_mappings.find(std::make_pair(T_A, name)); - if ( it == all_mappings.end() ) - return nullptr; - - auto d = it->second; - - if ( ! d || d->names.empty() ) - return nullptr; - - if ( cleanup_expired && (d && d->Expired()) ) - { - all_mappings.erase(it); - - // If the TTL is zero, we're immediately expiring the response. We don't want - // to return though because the response was valid for a brief moment in time. - if ( d->TTL() != 0 ) - return nullptr; - } - - if ( check_failed && (d && d->Failed()) ) - { - reporter->Warning("Can't resolve host: %s", name.c_str()); - return empty_addr_set(); - } - - return d->AddrsSet(); - } - -StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired, bool check_failed) - { - auto it = all_mappings.find(addr); - if ( it == all_mappings.end() ) - return nullptr; - - auto d = it->second; - - if ( cleanup_expired && d->Expired() ) - { - all_mappings.erase(it); - - // If the TTL is zero, we're immediately expiring the response. We don't want - // to return though because the response was valid for a brief moment in time. - if ( d->TTL() != 0 ) - return nullptr; - } - else if ( check_failed && d->Failed() ) - { - std::string s(addr); - reporter->Warning("can't resolve IP address: %s", s.c_str()); - return make_intrusive(s); - } - - if ( d->Host() ) - return d->Host(); - - return make_intrusive("<\?\?\?>"); - } - -StringValPtr DNS_Mgr::LookupOtherInCache(const std::string& name, int request_type, - bool cleanup_expired) - { - auto it = all_mappings.find(std::make_pair(request_type, name)); - if ( it == all_mappings.end() ) - return nullptr; - - auto d = it->second; - - if ( cleanup_expired && d->Expired() ) - { - all_mappings.erase(it); - - // If the TTL is zero, we're immediately expiring the response. We don't want - // to return though because the response was valid for a brief moment in time. - if ( d->TTL() != 0 ) - return nullptr; - } - - if ( d->Host() ) - return d->Host(); - - return make_intrusive("<\?\?\?>"); - } - -void DNS_Mgr::IssueAsyncRequests() - { - while ( ! asyncs_queued.empty() && asyncs_pending < MAX_PENDING_REQUESTS ) - { - DNS_Request* dns_req = nullptr; - AsyncRequest* req = asyncs_queued.front(); - asyncs_queued.pop_front(); - - ++num_requests; - req->time = util::current_time(); - - if ( req->type == T_PTR ) - dns_req = new DNS_Request(req->addr, true); - else if ( req->type == T_A || req->type == T_AAAA ) - // We pass T_A here, but DNSRequest::MakeRequest() will special-case that in - // a request that gets both T_A and T_AAAA results at one time. - dns_req = new DNS_Request(req->host.c_str(), T_A, true); - else - dns_req = new DNS_Request(req->host.c_str(), req->type, true); - - dns_req->MakeRequest(channel, this); - - ++asyncs_pending; - } - } - -void DNS_Mgr::CheckAsyncHostRequest(const std::string& host, bool timeout) - { - // Note that this code is a mirror of that for CheckAsyncAddrRequest. - auto i = asyncs.find(std::make_pair(T_A, host)); - - if ( i != asyncs.end() ) - { - if ( timeout ) - { - ++failed; - i->second->Timeout(); - } - else if ( auto addrs = LookupNameInCache(host, true, false) ) - { - ++successful; - i->second->Resolved(addrs); - } - else - return; - - delete i->second; - asyncs.erase(i); - --asyncs_pending; - } - } - -void DNS_Mgr::CheckAsyncAddrRequest(const IPAddr& addr, bool timeout) - { - // Note that this code is a mirror of that for CheckAsyncHostRequest. - - // In the following, if it's not in the respective map anymore, we've - // already finished it earlier and don't have anything to do. - auto i = asyncs.find(addr); - - if ( i != asyncs.end() ) - { - if ( timeout ) - { - ++failed; - i->second->Timeout(); - } - else if ( auto name = LookupAddrInCache(addr, true, false) ) - { - ++successful; - i->second->Resolved(name->CheckString()); - } - else - return; - - delete i->second; - asyncs.erase(i); - --asyncs_pending; - } - } - -void DNS_Mgr::CheckAsyncOtherRequest(const std::string& host, bool timeout, int request_type) - { - // Note that this code is a mirror of that for CheckAsyncAddrRequest. - - auto i = asyncs.find(std::make_pair(request_type, host)); - if ( i != asyncs.end() ) - { - if ( timeout ) - { - ++failed; - i->second->Timeout(); - } - else if ( auto name = LookupOtherInCache(host, request_type, true) ) - { - ++successful; - i->second->Resolved(name->CheckString()); - } - else - return; - - delete i->second; - asyncs.erase(i); - --asyncs_pending; - } - } - -void DNS_Mgr::Flush() - { - Resolve(); - all_mappings.clear(); - } - -double DNS_Mgr::GetNextTimeout() - { - if ( asyncs_pending == 0 ) - return -1; - - int nfds = 0; - ares_socket_t socks[ARES_GETSOCK_MAXNUM]; - int bitmap = ares_getsock(channel, socks, ARES_GETSOCK_MAXNUM); - for ( int i = 0; i < ARES_GETSOCK_MAXNUM; i++ ) - { - if ( ARES_GETSOCK_READABLE(bitmap, i) || ARES_GETSOCK_WRITABLE(bitmap, i) ) - ++nfds; - } - - // Do we have any sockets that are read or writable? - if ( nfds == 0 ) - return -1; - - struct timeval tv; - tv.tv_sec = DNS_TIMEOUT; - tv.tv_usec = 0; - - struct timeval* tvp = ares_timeout(channel, &tv, &tv); - - return static_cast(tvp->tv_sec) + (static_cast(tvp->tv_usec) / 1e6); - } - -void DNS_Mgr::ProcessFd(int fd, int flags) - { - if ( socket_fds.count(fd) != 0 ) - { - int read_fd = (flags & IOSource::ProcessFlags::READ) != 0 ? fd : ARES_SOCKET_BAD; - int write_fd = (flags & IOSource::ProcessFlags::WRITE) != 0 ? fd : ARES_SOCKET_BAD; - ares_process_fd(channel, read_fd, write_fd); - } - - IssueAsyncRequests(); - } - -void DNS_Mgr::Process() - { - // Process() is called when DNS_Mgr is found "ready" when its - // GetNextTimeout() returns 0.0, but there's no active FD. - // - // Kick off timeouts at least. - ares_process_fd(channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); - } - -void DNS_Mgr::GetStats(Stats* stats) - { - // TODO: can this use the telemetry framework? - stats->requests = num_requests; - stats->successful = successful; - stats->failed = failed; - stats->pending = asyncs_pending; - - stats->cached_hosts = 0; - stats->cached_addresses = 0; - stats->cached_texts = 0; - stats->cached_total = all_mappings.size(); - - for ( const auto& [key, mapping] : all_mappings ) - { - if ( mapping->ReqType() == T_PTR ) - stats->cached_addresses++; - else if ( mapping->ReqType() == T_A ) - stats->cached_hosts++; - else - stats->cached_texts++; - } - } - -void DNS_Mgr::AsyncRequest::Resolved(const std::string& name) - { - for ( const auto& cb : callbacks ) - { - cb->Resolved(name); - if ( ! doctest::is_running_in_test ) - delete cb; - } - - callbacks.clear(); - processed = true; - } - -void DNS_Mgr::AsyncRequest::Resolved(TableValPtr addrs) - { - for ( const auto& cb : callbacks ) - { - cb->Resolved(addrs); - if ( ! doctest::is_running_in_test ) - delete cb; - } - - callbacks.clear(); - processed = true; - } - -void DNS_Mgr::AsyncRequest::Timeout() - { - for ( const auto& cb : callbacks ) - { - cb->Timeout(); - if ( ! doctest::is_running_in_test ) - delete cb; - } - - callbacks.clear(); - processed = true; - } - -TableValPtr DNS_Mgr::empty_addr_set() - { - // TODO: can this be returned statically as well? Does the result get used in a way - // that would modify the same value being returned repeatedly? - auto addr_t = base_type(TYPE_ADDR); - auto set_index = make_intrusive(addr_t); - set_index->Append(std::move(addr_t)); - auto s = make_intrusive(std::move(set_index), nullptr); - return make_intrusive(std::move(s)); - } +static void sock_cb(void* data, ares_socket_t s, int read, int write) { + auto mgr = reinterpret_cast(data); + mgr->RegisterSocket((int)s, read == 1, write == 1); +} + +DNS_Mgr::DNS_Mgr(DNS_MgrMode arg_mode) : IOSource(true), mode(arg_mode) { ares_library_init(ARES_LIB_INIT_ALL); } + +DNS_Mgr::~DNS_Mgr() { + Flush(); + + ares_cancel(channel); + ares_destroy(channel); + ares_library_cleanup(); +} + +void DNS_Mgr::Done() { + shutting_down = true; + Flush(); +} + +void DNS_Mgr::RegisterSocket(int fd, bool read, bool write) { + if ( read && socket_fds.count(fd) == 0 ) { + socket_fds.insert(fd); + iosource_mgr->RegisterFd(fd, this, IOSource::READ); + } + else if ( ! read && socket_fds.count(fd) != 0 ) { + socket_fds.erase(fd); + iosource_mgr->UnregisterFd(fd, this, IOSource::READ); + } + + if ( write && write_socket_fds.count(fd) == 0 ) { + write_socket_fds.insert(fd); + iosource_mgr->RegisterFd(fd, this, IOSource::WRITE); + } + else if ( ! write && write_socket_fds.count(fd) != 0 ) { + write_socket_fds.erase(fd); + iosource_mgr->UnregisterFd(fd, this, IOSource::WRITE); + } +} + +void DNS_Mgr::InitSource() { + if ( did_init ) + return; + + ares_options options; + int optmask = 0; + + // Enable an EDNS option to be sent with the requests. This allows us to set + // a bigger UDP buffer size in the request, which prevents fallback to TCP + // at least up to that size. + options.flags = ARES_FLAG_EDNS; + optmask |= ARES_OPT_FLAGS; + + options.ednspsz = MAX_UDP_BUFFER_SIZE; + optmask |= ARES_OPT_EDNSPSZ; + + options.socket_receive_buffer_size = MAX_UDP_BUFFER_SIZE; + optmask |= ARES_OPT_SOCK_RCVBUF; + + // This option is in milliseconds. + options.timeout = DNS_TIMEOUT * 1000; + optmask |= ARES_OPT_TIMEOUTMS; + + // This causes c-ares to only attempt each server twice before + // giving up. + options.tries = 2; + optmask |= ARES_OPT_TRIES; + + // See the comment on sock_cb for how this gets used. + options.sock_state_cb = sock_cb; + options.sock_state_cb_data = this; + optmask |= ARES_OPT_SOCK_STATE_CB; + + int status = ares_init_options(&channel, &options, optmask); + if ( status != ARES_SUCCESS ) + reporter->FatalError("Failed to initialize c-ares for DNS resolution: %s", ares_strerror(status)); + + // Note that Init() may be called by way of LookupHost() during the act of + // parsing a hostname literal (e.g. google.com), so we can't use a + // script-layer option to configure the DNS resolver as it may not be + // configured to the user's desired address at the time when we need to to + // the lookup. + auto dns_resolver = getenv("ZEEK_DNS_RESOLVER"); + if ( dns_resolver ) { + ares_addr_node servers; + servers.next = NULL; + + auto dns_resolver_addr = IPAddr(dns_resolver); + + if ( dns_resolver_addr.GetFamily() == IPv4 ) { + servers.family = AF_INET; + dns_resolver_addr.CopyIPv4(&(servers.addr.addr4)); + } + else { + struct sockaddr_in6 sa = {0}; + sa.sin6_family = AF_INET6; + dns_resolver_addr.CopyIPv6(&sa.sin6_addr); + + servers.family = AF_INET6; + memcpy(&(servers.addr.addr6), &sa.sin6_addr, sizeof(ares_in6_addr)); + } + + ares_set_servers(channel, &servers); + } + + did_init = true; +} + +void DNS_Mgr::InitPostScript() { + if ( ! doctest::is_running_in_test ) { + dm_rec = id::find_type("dns_mapping"); + + // Registering will call InitSource(), which sets up all of the DNS library stuff + iosource_mgr->Register(this, true); + } + else { + // This would normally be called when registering the iosource above. + InitSource(); + } + + // Load the DNS cache from disk, if it exists. + std::string cache_dir = dir.empty() ? "." : dir; + cache_name = util::fmt("%s/%s", cache_dir.c_str(), ".zeek-dns-cache"); + LoadCache(cache_name); +} + +static TableValPtr fake_name_lookup_result(const std::string& name) { + hash128_t hash; + KeyedHash::StaticHash128(name.c_str(), name.size(), &hash); + auto hv = make_intrusive(TYPE_ADDR); + hv->Append(make_intrusive(reinterpret_cast(&hash))); + return hv->ToSetVal(); +} + +static std::string fake_lookup_result(const std::string& name, int request_type) { + return util::fmt("fake_lookup_result_%s_%s", request_type_string(request_type), name.c_str()); +} + +static std::string fake_addr_lookup_result(const IPAddr& addr) { + return util::fmt("fake_addr_lookup_result_%s", addr.AsString().c_str()); +} + +static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, TableValPtr result) { + callback->Resolved(std::move(result)); + delete callback; +} + +static void resolve_lookup_cb(DNS_Mgr::LookupCallback* callback, const std::string& result) { + callback->Resolved(result); + delete callback; +} + +ValPtr DNS_Mgr::Lookup(const std::string& name, int request_type) { + if ( shutting_down ) + return nullptr; + + if ( request_type == T_A || request_type == T_AAAA ) + return LookupHost(name); + + if ( mode == DNS_FAKE ) + return make_intrusive(fake_lookup_result(name, request_type)); + + InitSource(); + + if ( mode != DNS_PRIME ) { + if ( auto val = LookupOtherInCache(name, request_type, false) ) + return val; + } + + switch ( mode ) { + case DNS_PRIME: { + auto req = new DNS_Request(name, request_type); + req->MakeRequest(channel, this); + return empty_addr_set(); + } + + case DNS_FORCE: + reporter->FatalError("can't find DNS entry for %s (req type %d / %s) in cache", name.c_str(), request_type, + request_type_string(request_type)); + return nullptr; + + case DNS_DEFAULT: { + auto req = new DNS_Request(name, request_type); + req->MakeRequest(channel, this); + Resolve(); + + // Call LookupHost() a second time to get the newly stored value out of the cache. + return Lookup(name, request_type); + } + + default: reporter->InternalError("bad mode %d in DNS_Mgr::Lookup", mode); return nullptr; + } + + return nullptr; +} + +TableValPtr DNS_Mgr::LookupHost(const std::string& name) { + if ( shutting_down ) + return nullptr; + + if ( mode == DNS_FAKE ) + return fake_name_lookup_result(name); + + InitSource(); + + // Check the cache before attempting to look up the name remotely. + if ( mode != DNS_PRIME ) { + if ( auto val = LookupNameInCache(name, false, true) ) + return val; + } + + // Not found, or priming. + switch ( mode ) { + case DNS_PRIME: { + // We pass T_A here, but DNSRequest::MakeRequest() will special-case that in + // a request that gets both T_A and T_AAAA results at one time. + auto req = new DNS_Request(name, T_A); + req->MakeRequest(channel, this); + return empty_addr_set(); + } + + case DNS_FORCE: reporter->FatalError("can't find DNS entry for %s in cache", name.c_str()); return nullptr; + + case DNS_DEFAULT: { + // We pass T_A here, but DNSRequest::MakeRequest() will special-case that in + // a request that gets both T_A and T_AAAA results at one time. + auto req = new DNS_Request(name, T_A); + req->MakeRequest(channel, this); + Resolve(); + + // Call LookupHost() a second time to get the newly stored value out of the cache. + return LookupHost(name); + } + + default: reporter->InternalError("bad mode in DNS_Mgr::LookupHost"); return nullptr; + } +} + +StringValPtr DNS_Mgr::LookupAddr(const IPAddr& addr) { + if ( shutting_down ) + return nullptr; + + if ( mode == DNS_FAKE ) + return make_intrusive(fake_addr_lookup_result(addr)); + + InitSource(); + + // Check the cache before attempting to look up the name remotely. + if ( mode != DNS_PRIME ) { + if ( auto val = LookupAddrInCache(addr, false, true) ) + return val; + } + + // Not found, or priming. + switch ( mode ) { + case DNS_PRIME: { + auto req = new DNS_Request(addr); + req->MakeRequest(channel, this); + return make_intrusive(""); + } + + case DNS_FORCE: + reporter->FatalError("can't find DNS entry for %s in cache", addr.AsString().c_str()); + return nullptr; + + case DNS_DEFAULT: { + auto req = new DNS_Request(addr); + req->MakeRequest(channel, this); + Resolve(); + + // Call LookupAddr() a second time to get the newly stored value out of the cache. + return LookupAddr(addr); + } + + default: reporter->InternalError("bad mode in DNS_Mgr::LookupAddr"); return nullptr; + } +} + +void DNS_Mgr::LookupHost(const std::string& name, LookupCallback* callback) { + if ( shutting_down ) + return; + + if ( mode == DNS_FAKE ) { + resolve_lookup_cb(callback, fake_name_lookup_result(name)); + return; + } + + // Do we already know the answer? + if ( auto addrs = LookupNameInCache(name, true, false) ) { + resolve_lookup_cb(callback, std::move(addrs)); + return; + } + + AsyncRequest* req = nullptr; + + // If we already have a request waiting for this host, we don't need to make + // another one. We can just add the callback to it and it'll get handled + // when the first request comes back. + auto key = std::make_pair(T_A, name); + auto i = asyncs.find(key); + if ( i != asyncs.end() ) + req = i->second; + else { + // A new one. + req = new AsyncRequest{name, T_A}; + asyncs_queued.push_back(req); + asyncs.emplace_hint(i, std::move(key), req); + } + + req->callbacks.push_back(callback); + + // There may be requests in the queue that haven't been processed yet + // so go ahead and reissue them, even if this method didn't change + // anything. + IssueAsyncRequests(); +} + +void DNS_Mgr::LookupAddr(const IPAddr& addr, LookupCallback* callback) { + if ( shutting_down ) + return; + + if ( mode == DNS_FAKE ) { + resolve_lookup_cb(callback, fake_addr_lookup_result(addr)); + return; + } + + // Do we already know the answer? + if ( auto name = LookupAddrInCache(addr, true, false) ) { + resolve_lookup_cb(callback, name->CheckString()); + return; + } + + AsyncRequest* req = nullptr; + + // If we already have a request waiting for this host, we don't need to make + // another one. We can just add the callback to it and it'll get handled + // when the first request comes back. + auto i = asyncs.find(addr); + if ( i != asyncs.end() ) + req = i->second; + else { + // A new one. + req = new AsyncRequest{addr}; + asyncs_queued.push_back(req); + asyncs.emplace_hint(i, addr, req); + } + + req->callbacks.push_back(callback); + + // There may be requests in the queue that haven't been processed yet + // so go ahead and reissue them, even if this method didn't change + // anything. + IssueAsyncRequests(); +} + +void DNS_Mgr::Lookup(const std::string& name, int request_type, LookupCallback* callback) { + if ( shutting_down ) + return; + + if ( mode == DNS_FAKE ) { + resolve_lookup_cb(callback, fake_lookup_result(name, request_type)); + return; + } + + // Do we already know the answer? + if ( auto txt = LookupOtherInCache(name, request_type, true) ) { + resolve_lookup_cb(callback, txt->CheckString()); + return; + } + + AsyncRequest* req = nullptr; + + // If we already have a request waiting for this host, we don't need to make + // another one. We can just add the callback to it and it'll get handled + // when the first request comes back. + auto key = std::make_pair(request_type, name); + auto i = asyncs.find(key); + if ( i != asyncs.end() ) + req = i->second; + else { + // A new one. + req = new AsyncRequest{name, request_type}; + asyncs_queued.push_back(req); + asyncs.emplace_hint(i, std::move(key), req); + } + + req->callbacks.push_back(callback); + + IssueAsyncRequests(); +} + +void DNS_Mgr::Resolve() { + int nfds = 0; + struct timeval *tvp, tv; + struct pollfd pollfds[ARES_GETSOCK_MAXNUM]; + ares_socket_t socks[ARES_GETSOCK_MAXNUM]; + + tv.tv_sec = DNS_TIMEOUT; + tv.tv_usec = 0; + + for ( int i = 0; i < MAX_PENDING_REQUESTS; i++ ) { + int nfds = 0; + int bitmap = ares_getsock(channel, socks, ARES_GETSOCK_MAXNUM); + + for ( int i = 0; i < ARES_GETSOCK_MAXNUM; i++ ) { + bool rd = ARES_GETSOCK_READABLE(bitmap, i); + bool wr = ARES_GETSOCK_WRITABLE(bitmap, i); + if ( rd || wr ) { + pollfds[nfds].fd = socks[i]; + pollfds[nfds].events = rd ? POLLIN : 0; + pollfds[nfds].events |= wr ? POLLOUT : 0; + ++nfds; + } + } + + // Do we have any sockets that are read or writable? + if ( nfds == 0 ) + break; + + // poll() timeout is in milliseconds. + tvp = ares_timeout(channel, &tv, &tv); + int timeout_ms = tvp->tv_sec * 1000 + tvp->tv_usec / 1000; + + int res = poll(pollfds, nfds, timeout_ms); + + if ( res > 0 ) { + for ( int i = 0; i < nfds; i++ ) { + int rdfd = pollfds[i].revents & POLLIN ? pollfds[i].fd : ARES_SOCKET_BAD; + int wrfd = pollfds[i].revents & POLLOUT ? pollfds[i].fd : ARES_SOCKET_BAD; + + if ( rdfd != ARES_SOCKET_BAD || wrfd != ARES_SOCKET_BAD ) + ares_process_fd(channel, rdfd, wrfd); + } + } + else if ( res == 0 ) + // Do timeout processing when poll() timed out. + ares_process_fd(channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); + } +} + +void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& dm) { + if ( e ) + event_mgr.Enqueue(e, BuildMappingVal(dm)); +} + +void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2) { + if ( e ) + event_mgr.Enqueue(e, BuildMappingVal(dm), l1->ToSetVal(), l2->ToSetVal()); +} + +void DNS_Mgr::Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm) { + if ( e ) + event_mgr.Enqueue(e, BuildMappingVal(old_dm), BuildMappingVal(new_dm)); +} + +ValPtr DNS_Mgr::BuildMappingVal(const DNS_MappingPtr& dm) { + if ( ! dm_rec ) + return nullptr; + + auto r = make_intrusive(dm_rec); + + r->AssignTime(0, dm->CreationTime()); + r->Assign(1, dm->ReqHost() ? dm->ReqHost() : ""); + r->Assign(2, make_intrusive(dm->ReqAddr())); + r->Assign(3, dm->Valid()); + + auto h = dm->Host(); + r->Assign(4, h ? std::move(h) : make_intrusive("")); + r->Assign(5, dm->AddrsSet()); + + return r; +} + +void DNS_Mgr::AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool merge) { + // TODO: the existing code doesn't handle hostname aliases at all. Should we? + + DNS_MappingPtr new_mapping = nullptr; + DNS_MappingPtr prev_mapping = nullptr; + bool keep_prev = true; + + MappingMap::iterator it; + if ( dr->RequestType() == T_PTR ) { + new_mapping = std::make_shared(dr->Addr(), h, ttl); + it = all_mappings.find(dr->Addr()); + if ( it == all_mappings.end() ) { + auto result = all_mappings.emplace(dr->Addr(), new_mapping); + it = result.first; + } + else + prev_mapping = it->second; + } + else { + new_mapping = std::make_shared(dr->Host(), h, ttl, dr->RequestType()); + auto key = std::make_pair(dr->RequestType(), dr->Host()); + + it = all_mappings.find(key); + if ( it == all_mappings.end() ) { + auto result = all_mappings.emplace(std::move(key), new_mapping); + it = result.first; + } + else + prev_mapping = it->second; + } + + if ( prev_mapping && prev_mapping->Valid() ) { + if ( new_mapping->Valid() ) { + if ( merge ) + new_mapping->Merge(prev_mapping); + + it->second = new_mapping; + keep_prev = false; + } + } + else { + it->second = new_mapping; + keep_prev = false; + } + + if ( prev_mapping && ! dr->IsTxt() ) + CompareMappings(prev_mapping, new_mapping); + + if ( keep_prev ) + new_mapping.reset(); + else + prev_mapping.reset(); +} + +void DNS_Mgr::CompareMappings(const DNS_MappingPtr& prev_mapping, const DNS_MappingPtr& new_mapping) { + if ( prev_mapping->Failed() ) { + if ( new_mapping->Failed() ) + // Nothing changed. + return; + + Event(dns_mapping_valid, new_mapping); + return; + } + + else if ( new_mapping->Failed() ) { + Event(dns_mapping_unverified, prev_mapping); + return; + } + + auto prev_s = prev_mapping->Host(); + auto new_s = new_mapping->Host(); + + if ( prev_s || new_s ) { + if ( ! prev_s ) + Event(dns_mapping_new_name, new_mapping); + else if ( ! new_s ) + Event(dns_mapping_lost_name, prev_mapping); + else if ( ! Bstr_eq(new_s->AsString(), prev_s->AsString()) ) + Event(dns_mapping_name_changed, prev_mapping, new_mapping); + } + + auto prev_a = prev_mapping->Addrs(); + auto new_a = new_mapping->Addrs(); + + if ( ! prev_a || ! new_a ) { + reporter->InternalWarning("confused in DNS_Mgr::CompareMappings"); + return; + } + + auto prev_delta = AddrListDelta(prev_a, new_a); + auto new_delta = AddrListDelta(new_a, prev_a); + + if ( prev_delta->Length() > 0 || new_delta->Length() > 0 ) + Event(dns_mapping_altered, new_mapping, std::move(prev_delta), std::move(new_delta)); +} + +ListValPtr DNS_Mgr::AddrListDelta(ListValPtr al1, ListValPtr al2) { + auto delta = make_intrusive(TYPE_ADDR); + + for ( int i = 0; i < al1->Length(); ++i ) { + const IPAddr& al1_i = al1->Idx(i)->AsAddr(); + + int j; + for ( j = 0; j < al2->Length(); ++j ) { + const IPAddr& al2_j = al2->Idx(j)->AsAddr(); + if ( al1_i == al2_j ) + break; + } + + if ( j >= al2->Length() ) + // Didn't find it. + delta->Append(al1->Idx(i)); + } + + return delta; +} + +void DNS_Mgr::LoadCache(const std::string& path) { + FILE* f = fopen(path.c_str(), "r"); + + if ( ! f ) + return; + + if ( ! DNS_Mapping::ValidateCacheVersion(f) ) { + fclose(f); + return; + } + + // Loop until we find a mapping that doesn't initialize correctly. + auto m = std::make_shared(f); + for ( ; ! m->NoMapping() && ! m->InitFailed(); m = std::make_shared(f) ) { + if ( m->ReqHost() ) + all_mappings.insert_or_assign(std::make_pair(m->ReqType(), m->ReqHost()), m); + else + all_mappings.insert_or_assign(m->ReqAddr(), m); + } + + if ( ! m->NoMapping() ) + reporter->FatalError("DNS cache corrupted"); + + fclose(f); +} + +bool DNS_Mgr::Save() { + if ( cache_name.empty() ) + return false; + + FILE* f = fopen(cache_name.c_str(), "w"); + + if ( ! f ) + return false; + + DNS_Mapping::InitializeCache(f); + Save(f, all_mappings); + + fclose(f); + + return true; +} + +void DNS_Mgr::Save(FILE* f, const MappingMap& m) { + for ( const auto& [key, mapping] : m ) { + if ( mapping ) + mapping->Save(f); + } +} + +TableValPtr DNS_Mgr::LookupNameInCache(const std::string& name, bool cleanup_expired, bool check_failed) { + auto it = all_mappings.find(std::make_pair(T_A, name)); + if ( it == all_mappings.end() ) + return nullptr; + + auto d = it->second; + + if ( ! d || d->names.empty() ) + return nullptr; + + if ( cleanup_expired && (d && d->Expired()) ) { + all_mappings.erase(it); + + // If the TTL is zero, we're immediately expiring the response. We don't want + // to return though because the response was valid for a brief moment in time. + if ( d->TTL() != 0 ) + return nullptr; + } + + if ( check_failed && (d && d->Failed()) ) { + reporter->Warning("Can't resolve host: %s", name.c_str()); + return empty_addr_set(); + } + + return d->AddrsSet(); +} + +StringValPtr DNS_Mgr::LookupAddrInCache(const IPAddr& addr, bool cleanup_expired, bool check_failed) { + auto it = all_mappings.find(addr); + if ( it == all_mappings.end() ) + return nullptr; + + auto d = it->second; + + if ( cleanup_expired && d->Expired() ) { + all_mappings.erase(it); + + // If the TTL is zero, we're immediately expiring the response. We don't want + // to return though because the response was valid for a brief moment in time. + if ( d->TTL() != 0 ) + return nullptr; + } + else if ( check_failed && d->Failed() ) { + std::string s(addr); + reporter->Warning("can't resolve IP address: %s", s.c_str()); + return make_intrusive(s); + } + + if ( d->Host() ) + return d->Host(); + + return make_intrusive("<\?\?\?>"); +} + +StringValPtr DNS_Mgr::LookupOtherInCache(const std::string& name, int request_type, bool cleanup_expired) { + auto it = all_mappings.find(std::make_pair(request_type, name)); + if ( it == all_mappings.end() ) + return nullptr; + + auto d = it->second; + + if ( cleanup_expired && d->Expired() ) { + all_mappings.erase(it); + + // If the TTL is zero, we're immediately expiring the response. We don't want + // to return though because the response was valid for a brief moment in time. + if ( d->TTL() != 0 ) + return nullptr; + } + + if ( d->Host() ) + return d->Host(); + + return make_intrusive("<\?\?\?>"); +} + +void DNS_Mgr::IssueAsyncRequests() { + while ( ! asyncs_queued.empty() && asyncs_pending < MAX_PENDING_REQUESTS ) { + DNS_Request* dns_req = nullptr; + AsyncRequest* req = asyncs_queued.front(); + asyncs_queued.pop_front(); + + ++num_requests; + req->time = util::current_time(); + + if ( req->type == T_PTR ) + dns_req = new DNS_Request(req->addr, true); + else if ( req->type == T_A || req->type == T_AAAA ) + // We pass T_A here, but DNSRequest::MakeRequest() will special-case that in + // a request that gets both T_A and T_AAAA results at one time. + dns_req = new DNS_Request(req->host.c_str(), T_A, true); + else + dns_req = new DNS_Request(req->host.c_str(), req->type, true); + + dns_req->MakeRequest(channel, this); + + ++asyncs_pending; + } +} + +void DNS_Mgr::CheckAsyncHostRequest(const std::string& host, bool timeout) { + // Note that this code is a mirror of that for CheckAsyncAddrRequest. + auto i = asyncs.find(std::make_pair(T_A, host)); + + if ( i != asyncs.end() ) { + if ( timeout ) { + ++failed; + i->second->Timeout(); + } + else if ( auto addrs = LookupNameInCache(host, true, false) ) { + ++successful; + i->second->Resolved(addrs); + } + else + return; + + delete i->second; + asyncs.erase(i); + --asyncs_pending; + } +} + +void DNS_Mgr::CheckAsyncAddrRequest(const IPAddr& addr, bool timeout) { + // Note that this code is a mirror of that for CheckAsyncHostRequest. + + // In the following, if it's not in the respective map anymore, we've + // already finished it earlier and don't have anything to do. + auto i = asyncs.find(addr); + + if ( i != asyncs.end() ) { + if ( timeout ) { + ++failed; + i->second->Timeout(); + } + else if ( auto name = LookupAddrInCache(addr, true, false) ) { + ++successful; + i->second->Resolved(name->CheckString()); + } + else + return; + + delete i->second; + asyncs.erase(i); + --asyncs_pending; + } +} + +void DNS_Mgr::CheckAsyncOtherRequest(const std::string& host, bool timeout, int request_type) { + // Note that this code is a mirror of that for CheckAsyncAddrRequest. + + auto i = asyncs.find(std::make_pair(request_type, host)); + if ( i != asyncs.end() ) { + if ( timeout ) { + ++failed; + i->second->Timeout(); + } + else if ( auto name = LookupOtherInCache(host, request_type, true) ) { + ++successful; + i->second->Resolved(name->CheckString()); + } + else + return; + + delete i->second; + asyncs.erase(i); + --asyncs_pending; + } +} + +void DNS_Mgr::Flush() { + Resolve(); + all_mappings.clear(); +} + +double DNS_Mgr::GetNextTimeout() { + if ( asyncs_pending == 0 ) + return -1; + + int nfds = 0; + ares_socket_t socks[ARES_GETSOCK_MAXNUM]; + int bitmap = ares_getsock(channel, socks, ARES_GETSOCK_MAXNUM); + for ( int i = 0; i < ARES_GETSOCK_MAXNUM; i++ ) { + if ( ARES_GETSOCK_READABLE(bitmap, i) || ARES_GETSOCK_WRITABLE(bitmap, i) ) + ++nfds; + } + + // Do we have any sockets that are read or writable? + if ( nfds == 0 ) + return -1; + + struct timeval tv; + tv.tv_sec = DNS_TIMEOUT; + tv.tv_usec = 0; + + struct timeval* tvp = ares_timeout(channel, &tv, &tv); + + return static_cast(tvp->tv_sec) + (static_cast(tvp->tv_usec) / 1e6); +} + +void DNS_Mgr::ProcessFd(int fd, int flags) { + if ( socket_fds.count(fd) != 0 ) { + int read_fd = (flags & IOSource::ProcessFlags::READ) != 0 ? fd : ARES_SOCKET_BAD; + int write_fd = (flags & IOSource::ProcessFlags::WRITE) != 0 ? fd : ARES_SOCKET_BAD; + ares_process_fd(channel, read_fd, write_fd); + } + + IssueAsyncRequests(); +} + +void DNS_Mgr::Process() { + // Process() is called when DNS_Mgr is found "ready" when its + // GetNextTimeout() returns 0.0, but there's no active FD. + // + // Kick off timeouts at least. + ares_process_fd(channel, ARES_SOCKET_BAD, ARES_SOCKET_BAD); +} + +void DNS_Mgr::GetStats(Stats* stats) { + // TODO: can this use the telemetry framework? + stats->requests = num_requests; + stats->successful = successful; + stats->failed = failed; + stats->pending = asyncs_pending; + + stats->cached_hosts = 0; + stats->cached_addresses = 0; + stats->cached_texts = 0; + stats->cached_total = all_mappings.size(); + + for ( const auto& [key, mapping] : all_mappings ) { + if ( mapping->ReqType() == T_PTR ) + stats->cached_addresses++; + else if ( mapping->ReqType() == T_A ) + stats->cached_hosts++; + else + stats->cached_texts++; + } +} + +void DNS_Mgr::AsyncRequest::Resolved(const std::string& name) { + for ( const auto& cb : callbacks ) { + cb->Resolved(name); + if ( ! doctest::is_running_in_test ) + delete cb; + } + + callbacks.clear(); + processed = true; +} + +void DNS_Mgr::AsyncRequest::Resolved(TableValPtr addrs) { + for ( const auto& cb : callbacks ) { + cb->Resolved(addrs); + if ( ! doctest::is_running_in_test ) + delete cb; + } + + callbacks.clear(); + processed = true; +} + +void DNS_Mgr::AsyncRequest::Timeout() { + for ( const auto& cb : callbacks ) { + cb->Timeout(); + if ( ! doctest::is_running_in_test ) + delete cb; + } + + callbacks.clear(); + processed = true; +} + +TableValPtr DNS_Mgr::empty_addr_set() { + // TODO: can this be returned statically as well? Does the result get used in a way + // that would modify the same value being returned repeatedly? + auto addr_t = base_type(TYPE_ADDR); + auto set_index = make_intrusive(addr_t); + set_index->Append(std::move(addr_t)); + auto s = make_intrusive(std::move(set_index), nullptr); + return make_intrusive(std::move(s)); +} ////////////////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////////////// @@ -1600,327 +1368,305 @@ TableValPtr DNS_Mgr::empty_addr_set() // run them manually, pass the --no-skip flag when running tests. These tests are // run automatically as part of CI builds. -static std::vector get_result_addresses(TableValPtr addrs) - { - std::vector results; +static std::vector get_result_addresses(TableValPtr addrs) { + std::vector results; - auto m = addrs->ToMap(); - for ( const auto& [k, v] : m ) - { - auto lv = cast_intrusive(k); - auto lvv = lv->Vals(); - for ( const auto& addr : lvv ) - { - auto addr_ptr = cast_intrusive(addr); - results.push_back(addr_ptr->Get()); - } - } + auto m = addrs->ToMap(); + for ( const auto& [k, v] : m ) { + auto lv = cast_intrusive(k); + auto lvv = lv->Vals(); + for ( const auto& addr : lvv ) { + auto addr_ptr = cast_intrusive(addr); + results.push_back(addr_ptr->Get()); + } + } - return results; - } + return results; +} -class TestCallback : public DNS_Mgr::LookupCallback - { +class TestCallback : public DNS_Mgr::LookupCallback { public: - TestCallback() { } - void Resolved(const std::string& name) override - { - host_result = name; - done = true; - } - void Resolved(TableValPtr addrs) override - { - addr_results = get_result_addresses(addrs); - done = true; - } - void Timeout() override - { - timeout = true; - done = true; - } + TestCallback() {} + void Resolved(const std::string& name) override { + host_result = name; + done = true; + } + void Resolved(TableValPtr addrs) override { + addr_results = get_result_addresses(addrs); + done = true; + } + void Timeout() override { + timeout = true; + done = true; + } - std::string host_result; - std::vector addr_results; - bool done = false; - bool timeout = false; - }; + std::string host_result; + std::vector addr_results; + bool done = false; + bool timeout = false; +}; /** * Derived testing version of DNS_Mgr so that the Process() method can be exposed * publicly. If new unit tests are added, this class should be used over using * DNS_Mgr directly. */ -class TestDNS_Mgr final : public DNS_Mgr - { +class TestDNS_Mgr final : public DNS_Mgr { public: - explicit TestDNS_Mgr(DNS_MgrMode mode) : DNS_Mgr(mode) { } - void Process() override; - }; + explicit TestDNS_Mgr(DNS_MgrMode mode) : DNS_Mgr(mode) {} + void Process() override; +}; -void TestDNS_Mgr::Process() - { - // Only allow usage of this method when running unit tests. - assert(doctest::is_running_in_test); - Resolve(); - IssueAsyncRequests(); - } +void TestDNS_Mgr::Process() { + // Only allow usage of this method when running unit tests. + assert(doctest::is_running_in_test); + Resolve(); + IssueAsyncRequests(); +} -TEST_CASE("dns_mgr priming" * doctest::skip(true)) - { - // TODO: This test uses mkdtemp, which isn't available on Windows. +TEST_CASE("dns_mgr priming" * doctest::skip(true)) { + // TODO: This test uses mkdtemp, which isn't available on Windows. #ifndef _MSC_VER - char prefix[] = "/tmp/zeek-unit-test-XXXXXX"; - auto tmpdir = mkdtemp(prefix); + char prefix[] = "/tmp/zeek-unit-test-XXXXXX"; + auto tmpdir = mkdtemp(prefix); - // Create a manager to prime the cache, make a few requests, and the save - // the result. This tests that the priming code will create the requests but - // wait for Resolve() to actually make the requests. - TestDNS_Mgr mgr(DNS_PRIME); - mgr.SetDir(tmpdir); - mgr.InitPostScript(); + // Create a manager to prime the cache, make a few requests, and the save + // the result. This tests that the priming code will create the requests but + // wait for Resolve() to actually make the requests. + TestDNS_Mgr mgr(DNS_PRIME); + mgr.SetDir(tmpdir); + mgr.InitPostScript(); - auto host_result = mgr.LookupHost("one.one.one.one"); - REQUIRE(host_result != nullptr); - CHECK(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); + auto host_result = mgr.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + CHECK(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); - IPAddr ones("1.1.1.1"); - auto addr_result = mgr.LookupAddr(ones); - CHECK(strcmp(addr_result->CheckString(), "") == 0); + IPAddr ones("1.1.1.1"); + auto addr_result = mgr.LookupAddr(ones); + CHECK(strcmp(addr_result->CheckString(), "") == 0); - // This should wait until we have all of the results back from the above - // requests. - mgr.Resolve(); + // This should wait until we have all of the results back from the above + // requests. + mgr.Resolve(); - // Save off the resulting values from Resolve() into a file on disk - // in the tmpdir created by mkdtemp. - REQUIRE(mgr.Save()); + // Save off the resulting values from Resolve() into a file on disk + // in the tmpdir created by mkdtemp. + REQUIRE(mgr.Save()); - // Make a second DNS manager and reload the cache that we just saved. - TestDNS_Mgr mgr2(DNS_FORCE); - dns_mgr = &mgr2; - mgr2.SetDir(tmpdir); - mgr2.InitPostScript(); + // Make a second DNS manager and reload the cache that we just saved. + TestDNS_Mgr mgr2(DNS_FORCE); + dns_mgr = &mgr2; + mgr2.SetDir(tmpdir); + mgr2.InitPostScript(); - // Make the same two requests, but verify that we're correctly getting - // data out of the cache. - host_result = mgr2.LookupHost("one.one.one.one"); - REQUIRE(host_result != nullptr); - CHECK_FALSE(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); + // Make the same two requests, but verify that we're correctly getting + // data out of the cache. + host_result = mgr2.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + CHECK_FALSE(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); - addr_result = mgr2.LookupAddr(ones); - REQUIRE(addr_result != nullptr); - CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); + addr_result = mgr2.LookupAddr(ones); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); - // Clean up cache file and the temp directory - unlink(mgr2.CacheFile().c_str()); - rmdir(tmpdir); + // Clean up cache file and the temp directory + unlink(mgr2.CacheFile().c_str()); + rmdir(tmpdir); #endif - } +} -TEST_CASE("dns_mgr alternate server" * doctest::skip(true)) - { - char* old_server = getenv("ZEEK_DNS_RESOLVER"); +TEST_CASE("dns_mgr alternate server" * doctest::skip(true)) { + char* old_server = getenv("ZEEK_DNS_RESOLVER"); - setenv("ZEEK_DNS_RESOLVER", "1.1.1.1", 1); - TestDNS_Mgr mgr(DNS_DEFAULT); + setenv("ZEEK_DNS_RESOLVER", "1.1.1.1", 1); + TestDNS_Mgr mgr(DNS_DEFAULT); - mgr.InitPostScript(); + mgr.InitPostScript(); - auto result = mgr.LookupAddr("1.1.1.1"); - REQUIRE(result != nullptr); - CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0); + auto result = mgr.LookupAddr("1.1.1.1"); + REQUIRE(result != nullptr); + CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0); - // FIXME: This won't run on systems without IPv6 connectivity. - // setenv("ZEEK_DNS_RESOLVER", "2606:4700:4700::1111", 1); - // TestDNS_Mgr mgr2(DNS_DEFAULT, true); - // mgr2.InitPostScript(); - // result = mgr2.LookupAddr("1.1.1.1"); - // mgr2.Resolve(); + // FIXME: This won't run on systems without IPv6 connectivity. + // setenv("ZEEK_DNS_RESOLVER", "2606:4700:4700::1111", 1); + // TestDNS_Mgr mgr2(DNS_DEFAULT, true); + // mgr2.InitPostScript(); + // result = mgr2.LookupAddr("1.1.1.1"); + // mgr2.Resolve(); - // result = mgr2.LookupAddr("1.1.1.1"); - // CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0); + // result = mgr2.LookupAddr("1.1.1.1"); + // CHECK(strcmp(result->CheckString(), "one.one.one.one") == 0); - if ( old_server ) - setenv("ZEEK_DNS_RESOLVER", old_server, 1); - else - unsetenv("ZEEK_DNS_RESOLVER"); - } + if ( old_server ) + setenv("ZEEK_DNS_RESOLVER", old_server, 1); + else + unsetenv("ZEEK_DNS_RESOLVER"); +} -TEST_CASE("dns_mgr default mode" * doctest::skip(true)) - { - TestDNS_Mgr mgr(DNS_DEFAULT); - mgr.InitPostScript(); +TEST_CASE("dns_mgr default mode" * doctest::skip(true)) { + TestDNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); - IPAddr ones4("1.1.1.1"); - IPAddr ones6("2606:4700:4700::1111"); + IPAddr ones4("1.1.1.1"); + IPAddr ones6("2606:4700:4700::1111"); - auto host_result = mgr.LookupHost("one.one.one.one"); - REQUIRE(host_result != nullptr); - CHECK_FALSE(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); + auto host_result = mgr.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + CHECK_FALSE(host_result->EqualTo(TestDNS_Mgr::empty_addr_set())); - auto addrs_from_request = get_result_addresses(host_result); - auto it = std::find(addrs_from_request.begin(), addrs_from_request.end(), ones4); - CHECK(it != addrs_from_request.end()); - it = std::find(addrs_from_request.begin(), addrs_from_request.end(), ones6); - CHECK(it != addrs_from_request.end()); + auto addrs_from_request = get_result_addresses(host_result); + auto it = std::find(addrs_from_request.begin(), addrs_from_request.end(), ones4); + CHECK(it != addrs_from_request.end()); + it = std::find(addrs_from_request.begin(), addrs_from_request.end(), ones6); + CHECK(it != addrs_from_request.end()); - auto addr_result = mgr.LookupAddr(ones4); - REQUIRE(addr_result != nullptr); - CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); + auto addr_result = mgr.LookupAddr(ones4); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); - addr_result = mgr.LookupAddr(ones6); - REQUIRE(addr_result != nullptr); - CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); + addr_result = mgr.LookupAddr(ones6); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "one.one.one.one") == 0); - IPAddr bad("240.0.0.0"); - addr_result = mgr.LookupAddr(bad); - REQUIRE(addr_result != nullptr); - CHECK(strcmp(addr_result->CheckString(), "240.0.0.0") == 0); - } + IPAddr bad("240.0.0.0"); + addr_result = mgr.LookupAddr(bad); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "240.0.0.0") == 0); +} -TEST_CASE("dns_mgr async host" * doctest::skip(true)) - { - TestDNS_Mgr mgr(DNS_DEFAULT); - mgr.InitPostScript(); +TEST_CASE("dns_mgr async host" * doctest::skip(true)) { + TestDNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); - TestCallback cb{}; - mgr.LookupHost("one.one.one.one", &cb); + TestCallback cb{}; + mgr.LookupHost("one.one.one.one", &cb); - // This shouldn't take any longer than DNS_TIMEOUT+1 seconds, so bound it - // just in case of some failure we're not aware of yet. - int count = 0; - while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) - { - mgr.Process(); - sleep(1); - if ( ! cb.timeout ) - count++; - } + // This shouldn't take any longer than DNS_TIMEOUT+1 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { + mgr.Process(); + sleep(1); + if ( ! cb.timeout ) + count++; + } - REQUIRE(count < (DNS_TIMEOUT + 1)); - if ( ! cb.timeout ) - { - REQUIRE_FALSE(cb.addr_results.empty()); - IPAddr ones("1.1.1.1"); - auto it = std::find(cb.addr_results.begin(), cb.addr_results.end(), ones); - CHECK(it != cb.addr_results.end()); - } + REQUIRE(count < (DNS_TIMEOUT + 1)); + if ( ! cb.timeout ) { + REQUIRE_FALSE(cb.addr_results.empty()); + IPAddr ones("1.1.1.1"); + auto it = std::find(cb.addr_results.begin(), cb.addr_results.end(), ones); + CHECK(it != cb.addr_results.end()); + } - mgr.Flush(); - } + mgr.Flush(); +} -TEST_CASE("dns_mgr async addr" * doctest::skip(true)) - { - TestDNS_Mgr mgr(DNS_DEFAULT); - mgr.InitPostScript(); +TEST_CASE("dns_mgr async addr" * doctest::skip(true)) { + TestDNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); - TestCallback cb{}; - mgr.LookupAddr(IPAddr{"1.1.1.1"}, &cb); + TestCallback cb{}; + mgr.LookupAddr(IPAddr{"1.1.1.1"}, &cb); - // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it - // just in case of some failure we're not aware of yet. - int count = 0; - while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) - { - mgr.Process(); - sleep(1); - if ( ! cb.timeout ) - count++; - } + // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { + mgr.Process(); + sleep(1); + if ( ! cb.timeout ) + count++; + } - REQUIRE(count < (DNS_TIMEOUT + 1)); - if ( ! cb.timeout ) - REQUIRE(cb.host_result == "one.one.one.one"); + REQUIRE(count < (DNS_TIMEOUT + 1)); + if ( ! cb.timeout ) + REQUIRE(cb.host_result == "one.one.one.one"); - mgr.Flush(); - } + mgr.Flush(); +} -TEST_CASE("dns_mgr async text" * doctest::skip(true)) - { - TestDNS_Mgr mgr(DNS_DEFAULT); - mgr.InitPostScript(); +TEST_CASE("dns_mgr async text" * doctest::skip(true)) { + TestDNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); - TestCallback cb{}; - mgr.Lookup("unittest.zeek.org", T_TXT, &cb); + TestCallback cb{}; + mgr.Lookup("unittest.zeek.org", T_TXT, &cb); - // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it - // just in case of some failure we're not aware of yet. - int count = 0; - while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) - { - mgr.Process(); - sleep(1); - if ( ! cb.timeout ) - count++; - } + // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { + mgr.Process(); + sleep(1); + if ( ! cb.timeout ) + count++; + } - REQUIRE(count < (DNS_TIMEOUT + 1)); - if ( ! cb.timeout ) - REQUIRE(cb.host_result == "testing dns_mgr"); + REQUIRE(count < (DNS_TIMEOUT + 1)); + if ( ! cb.timeout ) + REQUIRE(cb.host_result == "testing dns_mgr"); - mgr.Flush(); - } + mgr.Flush(); +} -TEST_CASE("dns_mgr timeouts" * doctest::skip(true)) - { - char* old_server = getenv("ZEEK_DNS_RESOLVER"); +TEST_CASE("dns_mgr timeouts" * doctest::skip(true)) { + char* old_server = getenv("ZEEK_DNS_RESOLVER"); - // This is the address for blackhole.webpagetest.org, which provides a DNS - // server that lets you connect but never returns any responses, always - // resulting in a timeout. - setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); - TestDNS_Mgr mgr(DNS_DEFAULT); + // This is the address for blackhole.webpagetest.org, which provides a DNS + // server that lets you connect but never returns any responses, always + // resulting in a timeout. + setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); + TestDNS_Mgr mgr(DNS_DEFAULT); - mgr.InitPostScript(); - auto addr_result = mgr.LookupAddr("1.1.1.1"); - REQUIRE(addr_result != nullptr); - CHECK(strcmp(addr_result->CheckString(), "1.1.1.1") == 0); + mgr.InitPostScript(); + auto addr_result = mgr.LookupAddr("1.1.1.1"); + REQUIRE(addr_result != nullptr); + CHECK(strcmp(addr_result->CheckString(), "1.1.1.1") == 0); - auto host_result = mgr.LookupHost("one.one.one.one"); - REQUIRE(host_result != nullptr); - auto addresses = get_result_addresses(host_result); - CHECK(addresses.size() == 0); + auto host_result = mgr.LookupHost("one.one.one.one"); + REQUIRE(host_result != nullptr); + auto addresses = get_result_addresses(host_result); + CHECK(addresses.size() == 0); - if ( old_server ) - setenv("ZEEK_DNS_RESOLVER", old_server, 1); - else - unsetenv("ZEEK_DNS_RESOLVER"); - } + if ( old_server ) + setenv("ZEEK_DNS_RESOLVER", old_server, 1); + else + unsetenv("ZEEK_DNS_RESOLVER"); +} -TEST_CASE("dns_mgr async timeouts" * doctest::skip(true)) - { - char* old_server = getenv("ZEEK_DNS_RESOLVER"); +TEST_CASE("dns_mgr async timeouts" * doctest::skip(true)) { + char* old_server = getenv("ZEEK_DNS_RESOLVER"); - // This is the address for blackhole.webpagetest.org, which provides a DNS - // server that lets you connect but never returns any responses, always - // resulting in a timeout. - setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); - TestDNS_Mgr mgr(DNS_DEFAULT); - mgr.InitPostScript(); + // This is the address for blackhole.webpagetest.org, which provides a DNS + // server that lets you connect but never returns any responses, always + // resulting in a timeout. + setenv("ZEEK_DNS_RESOLVER", "3.219.212.117", 1); + TestDNS_Mgr mgr(DNS_DEFAULT); + mgr.InitPostScript(); - TestCallback cb{}; - mgr.Lookup("unittest.zeek.org", T_TXT, &cb); + TestCallback cb{}; + mgr.Lookup("unittest.zeek.org", T_TXT, &cb); - // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it - // just in case of some failure we're not aware of yet. - int count = 0; - while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) - { - mgr.Process(); - sleep(1); - if ( ! cb.timeout ) - count++; - } + // This shouldn't take any longer than DNS_TIMEOUT +1 seconds, so bound it + // just in case of some failure we're not aware of yet. + int count = 0; + while ( ! cb.done && (count < DNS_TIMEOUT + 1) ) { + mgr.Process(); + sleep(1); + if ( ! cb.timeout ) + count++; + } - REQUIRE(count < (DNS_TIMEOUT + 1)); - CHECK(cb.timeout); + REQUIRE(count < (DNS_TIMEOUT + 1)); + CHECK(cb.timeout); - mgr.Flush(); + mgr.Flush(); - if ( old_server ) - setenv("ZEEK_DNS_RESOLVER", old_server, 1); - else - unsetenv("ZEEK_DNS_RESOLVER"); - } + if ( old_server ) + setenv("ZEEK_DNS_RESOLVER", old_server, 1); + else + unsetenv("ZEEK_DNS_RESOLVER"); +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DNS_Mgr.h b/src/DNS_Mgr.h index 5f87313e95..3fed818a7f 100644 --- a/src/DNS_Mgr.h +++ b/src/DNS_Mgr.h @@ -27,326 +27,314 @@ typedef struct ares_channeldata* ares_channel; #define T_TXT 16 #endif -namespace zeek - { +namespace zeek { class Val; class ListVal; class TableVal; class StringVal; -template class IntrusivePtr; +template +class IntrusivePtr; using ValPtr = IntrusivePtr; using ListValPtr = IntrusivePtr; using TableValPtr = IntrusivePtr; using StringValPtr = IntrusivePtr; - } // namespace zeek +} // namespace zeek -namespace zeek::detail - { +namespace zeek::detail { class DNS_Mapping; using DNS_MappingPtr = std::shared_ptr; class DNS_Request; -enum DNS_MgrMode - { - DNS_PRIME, // used to prime the cache - DNS_FORCE, // internal error if cache miss - DNS_DEFAULT, // lookup names as they're requested - DNS_FAKE, // don't look up names, just return dummy results - }; +enum DNS_MgrMode { + DNS_PRIME, // used to prime the cache + DNS_FORCE, // internal error if cache miss + DNS_DEFAULT, // lookup names as they're requested + DNS_FAKE, // don't look up names, just return dummy results +}; -class DNS_Mgr : public iosource::IOSource - { +class DNS_Mgr : public iosource::IOSource { public: - /** - * Base class for callback handling for asynchronous lookups. - */ - class LookupCallback - { - public: - virtual ~LookupCallback() = default; + /** + * Base class for callback handling for asynchronous lookups. + */ + class LookupCallback { + public: + virtual ~LookupCallback() = default; - /** - * Called when an address lookup finishes. - * - * @param name The resulting name from the lookup. - */ - virtual void Resolved(const std::string& name){}; + /** + * Called when an address lookup finishes. + * + * @param name The resulting name from the lookup. + */ + virtual void Resolved(const std::string& name){}; - /** - * Called when a name lookup finishes. - * - * @param addrs A table of the resulting addresses from the lookup. - */ - virtual void Resolved(TableValPtr addrs){}; + /** + * Called when a name lookup finishes. + * + * @param addrs A table of the resulting addresses from the lookup. + */ + virtual void Resolved(TableValPtr addrs){}; - /** - * Generic callback method for all request types. - * - * @param val A Val containing the data from the query. - */ - virtual void Resolved(ValPtr data, int request_type) { } + /** + * Generic callback method for all request types. + * + * @param val A Val containing the data from the query. + */ + virtual void Resolved(ValPtr data, int request_type) {} - /** - * Called when a timeout request occurs. - */ - virtual void Timeout() = 0; - }; + /** + * Called when a timeout request occurs. + */ + virtual void Timeout() = 0; + }; - explicit DNS_Mgr(DNS_MgrMode mode); - ~DNS_Mgr() override; + explicit DNS_Mgr(DNS_MgrMode mode); + ~DNS_Mgr() override; - /** - * Finalizes the source when it's being closed. - */ - void Done() override; + /** + * Finalizes the source when it's being closed. + */ + void Done() override; - /** - * Finalizes the manager initialization. This should be called only after all - * of the scripts have been parsed at startup. - */ - void InitPostScript(); + /** + * Finalizes the manager initialization. This should be called only after all + * of the scripts have been parsed at startup. + */ + void InitPostScript(); - /** - * Attempts to process one more round of requests and then flushes the - * mapping caches. - */ - void Flush(); + /** + * Attempts to process one more round of requests and then flushes the + * mapping caches. + */ + void Flush(); - /** - * Looks up the address(es) of a given host and returns a set of addresses. - * This is a shorthand method for doing A/AAAA requests. This is a - * synchronous request and will block until the request completes or times - * out. - * - * @param host The hostname to lookup an address for. - * @return A set of addresses for the host. - */ - TableValPtr LookupHost(const std::string& host); + /** + * Looks up the address(es) of a given host and returns a set of addresses. + * This is a shorthand method for doing A/AAAA requests. This is a + * synchronous request and will block until the request completes or times + * out. + * + * @param host The hostname to lookup an address for. + * @return A set of addresses for the host. + */ + TableValPtr LookupHost(const std::string& host); - /** - * Looks up the hostname of a given address. This is a shorthand method for - * doing PTR requests. This is a synchronous request and will block until - * the request completes or times out. - * - * @param host The addr to lookup a hostname for. - * @return The hostname for the address. - */ - StringValPtr LookupAddr(const IPAddr& addr); + /** + * Looks up the hostname of a given address. This is a shorthand method for + * doing PTR requests. This is a synchronous request and will block until + * the request completes or times out. + * + * @param host The addr to lookup a hostname for. + * @return The hostname for the address. + */ + StringValPtr LookupAddr(const IPAddr& addr); - /** - * Performs a generic request to the DNS server. This is a synchronous - * request and will block until the request completes or times out. - * - * @param name The name or address to make a request for. If this is an - * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). - * Note that calling LookupAddr for PTR requests does this conversion - * automatically. - * @param request_type The type of request to make. This should be one of - * the type values defined in arpa/nameser.h or ares_nameser.h. - * @return The requested data. - */ - ValPtr Lookup(const std::string& name, int request_type); + /** + * Performs a generic request to the DNS server. This is a synchronous + * request and will block until the request completes or times out. + * + * @param name The name or address to make a request for. If this is an + * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). + * Note that calling LookupAddr for PTR requests does this conversion + * automatically. + * @param request_type The type of request to make. This should be one of + * the type values defined in arpa/nameser.h or ares_nameser.h. + * @return The requested data. + */ + ValPtr Lookup(const std::string& name, int request_type); - /** - * Looks up the address(es) of a given host. This is a shorthand method - * for doing A/AAAA requests. This is an asynchronous request. The - * response will be handled via the provided callback object. - * - * @param host The hostname to lookup an address for. - * @param callback A callback object for handling the response. - */ - void LookupHost(const std::string& host, LookupCallback* callback); + /** + * Looks up the address(es) of a given host. This is a shorthand method + * for doing A/AAAA requests. This is an asynchronous request. The + * response will be handled via the provided callback object. + * + * @param host The hostname to lookup an address for. + * @param callback A callback object for handling the response. + */ + void LookupHost(const std::string& host, LookupCallback* callback); - /** - * Looks up the hostname of a given address. This is a shorthand method for - * doing PTR requests. This is an asynchronous request. The response will - * be handled via the provided callback object. - * - * @param host The addr to lookup a hostname for. - * @param callback A callback object for handling the response. - */ - void LookupAddr(const IPAddr& addr, LookupCallback* callback); + /** + * Looks up the hostname of a given address. This is a shorthand method for + * doing PTR requests. This is an asynchronous request. The response will + * be handled via the provided callback object. + * + * @param host The addr to lookup a hostname for. + * @param callback A callback object for handling the response. + */ + void LookupAddr(const IPAddr& addr, LookupCallback* callback); - /** - * Performs a generic request to the DNS server. This is an asynchronous - * request. The response will be handled via the provided callback - * object. - * - * @param name The name or address to make a request for. If this is an - * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). - * Note that calling LookupAddr for PTR requests does this conversion - * automatically. - * @param request_type The type of request to make. This should be one of - * the type values defined in arpa/nameser.h or ares_nameser.h. - * @param callback A callback object for handling the response. - */ - void Lookup(const std::string& name, int request_type, LookupCallback* callback); + /** + * Performs a generic request to the DNS server. This is an asynchronous + * request. The response will be handled via the provided callback + * object. + * + * @param name The name or address to make a request for. If this is an + * address it should be in arpa format (x.x.x.x.in-addr.arpa or x-*.ip6.arpa). + * Note that calling LookupAddr for PTR requests does this conversion + * automatically. + * @param request_type The type of request to make. This should be one of + * the type values defined in arpa/nameser.h or ares_nameser.h. + * @param callback A callback object for handling the response. + */ + void Lookup(const std::string& name, int request_type, LookupCallback* callback); - /** - * Sets the directory where to store DNS data when Save() is called. - */ - void SetDir(const std::string& arg_dir) { dir = arg_dir; } + /** + * Sets the directory where to store DNS data when Save() is called. + */ + void SetDir(const std::string& arg_dir) { dir = arg_dir; } - /** - * Waits for responses to become available or a timeout to occur, - * and handles any responses. - */ - void Resolve(); + /** + * Waits for responses to become available or a timeout to occur, + * and handles any responses. + */ + void Resolve(); - /** - * Saves the current name and address caches to disk. - */ - bool Save(); + /** + * Saves the current name and address caches to disk. + */ + bool Save(); - struct Stats - { - unsigned long requests; // These count only async requests. - unsigned long successful; - unsigned long failed; - unsigned long pending; - unsigned long cached_hosts; - unsigned long cached_addresses; - unsigned long cached_texts; - unsigned long cached_total; - }; + struct Stats { + unsigned long requests; // These count only async requests. + unsigned long successful; + unsigned long failed; + unsigned long pending; + unsigned long cached_hosts; + unsigned long cached_addresses; + unsigned long cached_texts; + unsigned long cached_total; + }; - /** - * Returns the current statistics for the DNS_Manager. - * - * @param stats A pointer to a stats object to return the data in. - */ - void GetStats(Stats* stats); + /** + * Returns the current statistics for the DNS_Manager. + * + * @param stats A pointer to a stats object to return the data in. + */ + void GetStats(Stats* stats); - /** - * Adds a result from a request to the caches. This is public so that the - * callback methods can call it from outside of the DNS_Mgr class. - * - * @param dr The request associated with the result. - * @param h A hostent structure containing the actual result data. - * @param ttl A ttl value contained in the response from the server. - * @param merge A flag for whether these results should be merged into - * an existing mapping. If false, AddResult will attempt to replace the - * existing mapping with the new data and delete the old mapping. - */ - void AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool merge = false); + /** + * Adds a result from a request to the caches. This is public so that the + * callback methods can call it from outside of the DNS_Mgr class. + * + * @param dr The request associated with the result. + * @param h A hostent structure containing the actual result data. + * @param ttl A ttl value contained in the response from the server. + * @param merge A flag for whether these results should be merged into + * an existing mapping. If false, AddResult will attempt to replace the + * existing mapping with the new data and delete the old mapping. + */ + void AddResult(DNS_Request* dr, struct hostent* h, uint32_t ttl, bool merge = false); - /** - * Returns an empty set of addresses, used in various error cases and during - * cache priming. - */ - static TableValPtr empty_addr_set(); + /** + * Returns an empty set of addresses, used in various error cases and during + * cache priming. + */ + static TableValPtr empty_addr_set(); - /** - * Returns the full path to the file used to store the DNS cache. - */ - std::string CacheFile() const { return cache_name; } + /** + * Returns the full path to the file used to store the DNS cache. + */ + std::string CacheFile() const { return cache_name; } - /** - * Used by the c-ares socket call back to register/unregister a socket file descriptor. - */ - void RegisterSocket(int fd, bool read, bool write); + /** + * Used by the c-ares socket call back to register/unregister a socket file descriptor. + */ + void RegisterSocket(int fd, bool read, bool write); - ares_channel& GetChannel() { return channel; } + ares_channel& GetChannel() { return channel; } protected: - friend class LookupCallback; - friend class DNS_Request; + friend class LookupCallback; + friend class DNS_Request; - StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, - bool check_failed = false); - TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, - bool check_failed = false); - StringValPtr LookupOtherInCache(const std::string& name, int request_type, - bool cleanup_expired = false); + StringValPtr LookupAddrInCache(const IPAddr& addr, bool cleanup_expired = false, bool check_failed = false); + TableValPtr LookupNameInCache(const std::string& name, bool cleanup_expired = false, bool check_failed = false); + StringValPtr LookupOtherInCache(const std::string& name, int request_type, bool cleanup_expired = false); - // Finish the request if we have a result. If not, time it out if - // requested. - void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout); - void CheckAsyncHostRequest(const std::string& host, bool timeout); - void CheckAsyncOtherRequest(const std::string& host, bool timeout, int request_type); + // Finish the request if we have a result. If not, time it out if + // requested. + void CheckAsyncAddrRequest(const IPAddr& addr, bool timeout); + void CheckAsyncHostRequest(const std::string& host, bool timeout); + void CheckAsyncOtherRequest(const std::string& host, bool timeout, int request_type); - void Event(EventHandlerPtr e, const DNS_MappingPtr& dm); - void Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2); - void Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm); + void Event(EventHandlerPtr e, const DNS_MappingPtr& dm); + void Event(EventHandlerPtr e, const DNS_MappingPtr& dm, ListValPtr l1, ListValPtr l2); + void Event(EventHandlerPtr e, const DNS_MappingPtr& old_dm, DNS_MappingPtr new_dm); - ValPtr BuildMappingVal(const DNS_MappingPtr& dm); + ValPtr BuildMappingVal(const DNS_MappingPtr& dm); - void CompareMappings(const DNS_MappingPtr& prev_dm, const DNS_MappingPtr& new_dm); - ListValPtr AddrListDelta(ListValPtr al1, ListValPtr al2); + void CompareMappings(const DNS_MappingPtr& prev_dm, const DNS_MappingPtr& new_dm); + ListValPtr AddrListDelta(ListValPtr al1, ListValPtr al2); - using MappingKey = std::variant>; - using MappingMap = std::map; - void LoadCache(const std::string& path); - void Save(FILE* f, const MappingMap& m); + using MappingKey = std::variant>; + using MappingMap = std::map; + void LoadCache(const std::string& path); + void Save(FILE* f, const MappingMap& m); - // Issue as many queued async requests as slots are available. - void IssueAsyncRequests(); + // Issue as many queued async requests as slots are available. + void IssueAsyncRequests(); - // IOSource interface. - void Process() override; - void ProcessFd(int fd, int flags) override; - void InitSource() override; - const char* Tag() override { return "DNS_Mgr"; } - double GetNextTimeout() override; + // IOSource interface. + void Process() override; + void ProcessFd(int fd, int flags) override; + void InitSource() override; + const char* Tag() override { return "DNS_Mgr"; } + double GetNextTimeout() override; - DNS_MgrMode mode; + DNS_MgrMode mode; - MappingMap all_mappings; + MappingMap all_mappings; - std::string cache_name; - std::string dir; // directory in which cache_name resides + std::string cache_name; + std::string dir; // directory in which cache_name resides - bool did_init = false; - int asyncs_pending = 0; + bool did_init = false; + int asyncs_pending = 0; - RecordTypePtr dm_rec; + RecordTypePtr dm_rec; - ares_channel channel{}; + ares_channel channel{}; - using CallbackList = std::list; + using CallbackList = std::list; - struct AsyncRequest - { - double time = 0.0; - IPAddr addr; - std::string host; - CallbackList callbacks; - int type = 0; - bool processed = false; + struct AsyncRequest { + double time = 0.0; + IPAddr addr; + std::string host; + CallbackList callbacks; + int type = 0; + bool processed = false; - AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) - { - } - AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) { } + AsyncRequest(std::string host, int request_type) : host(std::move(host)), type(request_type) {} + AsyncRequest(const IPAddr& addr) : addr(addr), type(T_PTR) {} - void Resolved(const std::string& name); - void Resolved(TableValPtr addrs); - void Timeout(); - }; + void Resolved(const std::string& name); + void Resolved(TableValPtr addrs); + void Timeout(); + }; - struct AsyncRequestCompare - { - bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; } - }; + struct AsyncRequestCompare { + bool operator()(const AsyncRequest* a, const AsyncRequest* b) { return a->time > b->time; } + }; - using AsyncRequestMap = std::map; - AsyncRequestMap asyncs; + using AsyncRequestMap = std::map; + AsyncRequestMap asyncs; - using QueuedList = std::list; - QueuedList asyncs_queued; + using QueuedList = std::list; + QueuedList asyncs_queued; - unsigned long num_requests = 0; - unsigned long successful = 0; - unsigned long failed = 0; + unsigned long num_requests = 0; + unsigned long successful = 0; + unsigned long failed = 0; - std::set socket_fds; - std::set write_socket_fds; + std::set socket_fds; + std::set write_socket_fds; - bool shutting_down = false; - }; + bool shutting_down = false; +}; extern DNS_Mgr* dns_mgr; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DbgBreakpoint.cc b/src/DbgBreakpoint.cc index 32c87655e1..92912e14e3 100644 --- a/src/DbgBreakpoint.cc +++ b/src/DbgBreakpoint.cc @@ -18,357 +18,307 @@ #include "zeek/Val.h" #include "zeek/module_util.h" -namespace zeek::detail - { +namespace zeek::detail { // BreakpointTimer used for time-based breakpoints -class BreakpointTimer final : public Timer - { +class BreakpointTimer final : public Timer { public: - BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) - { - bp = arg_bp; - } + BreakpointTimer(DbgBreakpoint* arg_bp, double arg_t) : Timer(arg_t, TIMER_BREAKPOINT) { bp = arg_bp; } - void Dispatch(double t, bool is_expire) override; + void Dispatch(double t, bool is_expire) override; protected: - DbgBreakpoint* bp; - }; - -void BreakpointTimer::Dispatch(double t, bool is_expire) - { - if ( is_expire ) - return; - - bp->ShouldBreak(t); - } - -DbgBreakpoint::DbgBreakpoint() - { - kind = BP_STMT; - - enabled = temporary = false; - BPID = -1; - - at_stmt = nullptr; - at_time = -1.0; - - repeat_count = hit_count = 0; - - description[0] = 0; - source_filename = nullptr; - source_line = 0; - } - -DbgBreakpoint::~DbgBreakpoint() - { - SetEnable(false); // clean up any active state - RemoveFromGlobalMap(); - } - -bool DbgBreakpoint::SetEnable(bool do_enable) - { - bool old_value = enabled; - enabled = do_enable; - - // Update statement counts. - if ( do_enable && ! old_value ) - AddToStmt(); - - else if ( ! do_enable && old_value ) - RemoveFromStmt(); - - return old_value; - } - -void DbgBreakpoint::AddToGlobalMap() - { - // Make sure it's not there already. - RemoveFromGlobalMap(); - - g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this)); - } - -void DbgBreakpoint::RemoveFromGlobalMap() - { - std::pair p; - p = g_debugger_state.breakpoint_map.equal_range(at_stmt); - - for ( BPMapType::iterator i = p.first; i != p.second; ) - { - if ( i->second == this ) - { - BPMapType::iterator next = i; - ++next; - g_debugger_state.breakpoint_map.erase(i); - i = next; - } - else - ++i; - } - } - -void DbgBreakpoint::AddToStmt() - { - if ( at_stmt ) - at_stmt->IncrBPCount(); - } - -void DbgBreakpoint::RemoveFromStmt() - { - if ( at_stmt ) - at_stmt->DecrBPCount(); - } - -bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str) - { - if ( plr.type == PLR_UNKNOWN ) - { - debug_msg("Breakpoint specifier invalid or operation canceled.\n"); - return false; - } - - if ( plr.type == PLR_FILE_AND_LINE ) - { - kind = BP_LINE; - source_filename = plr.filename; - source_line = plr.line; - - if ( ! plr.stmt ) - { - debug_msg("No statement at that line.\n"); - return false; - } - - at_stmt = plr.stmt; - snprintf(description, sizeof(description), "%s:%d", source_filename, source_line); - - debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); - } - - else if ( plr.type == PLR_FUNCTION ) - { - std::string loc_s(loc_str); - kind = BP_FUNC; - function_name = make_full_var_name(current_module.c_str(), loc_s.c_str()); - at_stmt = plr.stmt; - const Location* loc = at_stmt->GetLocationInfo(); - snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), - loc->filename, loc->last_line); - - debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); - } - - SetEnable(true); - AddToGlobalMap(); - return true; - } - -bool DbgBreakpoint::SetLocation(Stmt* stmt) - { - if ( ! stmt ) - return false; - - kind = BP_STMT; - at_stmt = stmt; - - SetEnable(true); - AddToGlobalMap(); - - const Location* loc = stmt->GetLocationInfo(); - snprintf(description, sizeof(description), "%s:%d", loc->filename, loc->last_line); - - debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); - - return true; - } - -bool DbgBreakpoint::SetLocation(double t) - { - debug_msg("SetLocation(time) has not been debugged."); - return false; - - kind = BP_TIME; - at_time = t; - - timer_mgr->Add(new BreakpointTimer(this, t)); - - debug_msg("Time-based breakpoints not yet supported.\n"); - return false; - } - -bool DbgBreakpoint::Reset() - { - ParseLocationRec plr; - - switch ( kind ) - { - case BP_TIME: - debug_msg("Time-based breakpoints not yet supported.\n"); - break; - - case BP_FUNC: - case BP_STMT: - case BP_LINE: - plr.type = PLR_FUNCTION; - //### How to deal with wildcards? - //### perhaps save user choices?--tough... - break; - } - - reporter->InternalError("DbgBreakpoint::Reset function incomplete."); - - // Cannot be reached. - return false; - } - -bool DbgBreakpoint::SetCondition(const std::string& new_condition) - { - condition = new_condition; - return true; - } - -bool DbgBreakpoint::SetRepeatCount(int count) - { - repeat_count = count; - return true; - } - -BreakCode DbgBreakpoint::HasHit() - { - if ( temporary ) - { - SetEnable(false); - return BC_HIT_AND_DELETE; - } - - if ( condition.size() ) - { - // TODO: ### evaluate using debugger frame too - auto yes = dbg_eval_expr(condition.c_str()); - - if ( ! yes ) - { - debug_msg("Breakpoint condition '%s' invalid, removing condition.\n", - condition.c_str()); - SetCondition(""); - PrintHitMsg(); - return BC_HIT; - } - - if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) - { - PrintHitMsg(); - debug_msg("Breakpoint condition should return an integral type"); - return BC_HIT_AND_DELETE; - } - - yes->CoerceToInt(); - if ( yes->IsZero() ) - { - return BC_NO_HIT; - } - } - - int repcount = GetRepeatCount(); - if ( repcount ) - { - if ( ++hit_count == repcount ) - { - hit_count = 0; - PrintHitMsg(); - return BC_HIT; - } - - return BC_NO_HIT; - } - - PrintHitMsg(); - return BC_HIT; - } - -BreakCode DbgBreakpoint::ShouldBreak(Stmt* s) - { - if ( ! IsEnabled() ) - return BC_NO_HIT; - - switch ( kind ) - { - case BP_STMT: - case BP_FUNC: - if ( at_stmt != s ) - return BC_NO_HIT; - break; - - case BP_LINE: - assert(s->GetLocationInfo()->first_line <= source_line && - s->GetLocationInfo()->last_line >= source_line); - break; - - case BP_TIME: - assert(false); - - default: - reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak"); - } - - // If we got here, that means that the breakpoint could hit, - // except potentially if it has a special condition or a repeat count. - - BreakCode code = HasHit(); - if ( code ) - g_debugger_state.BreakBeforeNextStmt(true); - - return code; - } - -BreakCode DbgBreakpoint::ShouldBreak(double t) - { - if ( kind != BP_TIME ) - reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint"); - - if ( t < at_time ) - return BC_NO_HIT; - - if ( ! IsEnabled() ) - return BC_NO_HIT; - - BreakCode code = HasHit(); - if ( code ) - g_debugger_state.BreakBeforeNextStmt(true); - - return code; - } - -void DbgBreakpoint::PrintHitMsg() - { - switch ( kind ) - { - case BP_STMT: - case BP_FUNC: - case BP_LINE: - { - ODesc d; - Frame* f = g_frame_stack.back(); - const ScriptFunc* func = f->GetFunction(); - - if ( func ) - func->DescribeDebug(&d, f->GetFuncArgs()); - - const Location* loc = at_stmt->GetLocationInfo(); - - debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename, - loc->first_line); - } - return; - - case BP_TIME: - assert(false); - - default: - reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n"); - } - } - - } // namespace zeek::detail + DbgBreakpoint* bp; +}; + +void BreakpointTimer::Dispatch(double t, bool is_expire) { + if ( is_expire ) + return; + + bp->ShouldBreak(t); +} + +DbgBreakpoint::DbgBreakpoint() { + kind = BP_STMT; + + enabled = temporary = false; + BPID = -1; + + at_stmt = nullptr; + at_time = -1.0; + + repeat_count = hit_count = 0; + + description[0] = 0; + source_filename = nullptr; + source_line = 0; +} + +DbgBreakpoint::~DbgBreakpoint() { + SetEnable(false); // clean up any active state + RemoveFromGlobalMap(); +} + +bool DbgBreakpoint::SetEnable(bool do_enable) { + bool old_value = enabled; + enabled = do_enable; + + // Update statement counts. + if ( do_enable && ! old_value ) + AddToStmt(); + + else if ( ! do_enable && old_value ) + RemoveFromStmt(); + + return old_value; +} + +void DbgBreakpoint::AddToGlobalMap() { + // Make sure it's not there already. + RemoveFromGlobalMap(); + + g_debugger_state.breakpoint_map.insert(BPMapType::value_type(at_stmt, this)); +} + +void DbgBreakpoint::RemoveFromGlobalMap() { + std::pair p; + p = g_debugger_state.breakpoint_map.equal_range(at_stmt); + + for ( BPMapType::iterator i = p.first; i != p.second; ) { + if ( i->second == this ) { + BPMapType::iterator next = i; + ++next; + g_debugger_state.breakpoint_map.erase(i); + i = next; + } + else + ++i; + } +} + +void DbgBreakpoint::AddToStmt() { + if ( at_stmt ) + at_stmt->IncrBPCount(); +} + +void DbgBreakpoint::RemoveFromStmt() { + if ( at_stmt ) + at_stmt->DecrBPCount(); +} + +bool DbgBreakpoint::SetLocation(ParseLocationRec plr, std::string_view loc_str) { + if ( plr.type == PLR_UNKNOWN ) { + debug_msg("Breakpoint specifier invalid or operation canceled.\n"); + return false; + } + + if ( plr.type == PLR_FILE_AND_LINE ) { + kind = BP_LINE; + source_filename = plr.filename; + source_line = plr.line; + + if ( ! plr.stmt ) { + debug_msg("No statement at that line.\n"); + return false; + } + + at_stmt = plr.stmt; + snprintf(description, sizeof(description), "%s:%d", source_filename, source_line); + + debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); + } + + else if ( plr.type == PLR_FUNCTION ) { + std::string loc_s(loc_str); + kind = BP_FUNC; + function_name = make_full_var_name(current_module.c_str(), loc_s.c_str()); + at_stmt = plr.stmt; + const Location* loc = at_stmt->GetLocationInfo(); + snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), loc->filename, loc->last_line); + + debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); + } + + SetEnable(true); + AddToGlobalMap(); + return true; +} + +bool DbgBreakpoint::SetLocation(Stmt* stmt) { + if ( ! stmt ) + return false; + + kind = BP_STMT; + at_stmt = stmt; + + SetEnable(true); + AddToGlobalMap(); + + const Location* loc = stmt->GetLocationInfo(); + snprintf(description, sizeof(description), "%s:%d", loc->filename, loc->last_line); + + debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); + + return true; +} + +bool DbgBreakpoint::SetLocation(double t) { + debug_msg("SetLocation(time) has not been debugged."); + return false; + + kind = BP_TIME; + at_time = t; + + timer_mgr->Add(new BreakpointTimer(this, t)); + + debug_msg("Time-based breakpoints not yet supported.\n"); + return false; +} + +bool DbgBreakpoint::Reset() { + ParseLocationRec plr; + + switch ( kind ) { + case BP_TIME: debug_msg("Time-based breakpoints not yet supported.\n"); break; + + case BP_FUNC: + case BP_STMT: + case BP_LINE: + plr.type = PLR_FUNCTION; + //### How to deal with wildcards? + //### perhaps save user choices?--tough... + break; + } + + reporter->InternalError("DbgBreakpoint::Reset function incomplete."); + + // Cannot be reached. + return false; +} + +bool DbgBreakpoint::SetCondition(const std::string& new_condition) { + condition = new_condition; + return true; +} + +bool DbgBreakpoint::SetRepeatCount(int count) { + repeat_count = count; + return true; +} + +BreakCode DbgBreakpoint::HasHit() { + if ( temporary ) { + SetEnable(false); + return BC_HIT_AND_DELETE; + } + + if ( condition.size() ) { + // TODO: ### evaluate using debugger frame too + auto yes = dbg_eval_expr(condition.c_str()); + + if ( ! yes ) { + debug_msg("Breakpoint condition '%s' invalid, removing condition.\n", condition.c_str()); + SetCondition(""); + PrintHitMsg(); + return BC_HIT; + } + + if ( ! IsIntegral(yes->GetType()->Tag()) && ! IsBool(yes->GetType()->Tag()) ) { + PrintHitMsg(); + debug_msg("Breakpoint condition should return an integral type"); + return BC_HIT_AND_DELETE; + } + + yes->CoerceToInt(); + if ( yes->IsZero() ) { + return BC_NO_HIT; + } + } + + int repcount = GetRepeatCount(); + if ( repcount ) { + if ( ++hit_count == repcount ) { + hit_count = 0; + PrintHitMsg(); + return BC_HIT; + } + + return BC_NO_HIT; + } + + PrintHitMsg(); + return BC_HIT; +} + +BreakCode DbgBreakpoint::ShouldBreak(Stmt* s) { + if ( ! IsEnabled() ) + return BC_NO_HIT; + + switch ( kind ) { + case BP_STMT: + case BP_FUNC: + if ( at_stmt != s ) + return BC_NO_HIT; + break; + + case BP_LINE: + assert(s->GetLocationInfo()->first_line <= source_line && s->GetLocationInfo()->last_line >= source_line); + break; + + case BP_TIME: assert(false); + + default: reporter->InternalError("Invalid breakpoint type in DbgBreakpoint::ShouldBreak"); + } + + // If we got here, that means that the breakpoint could hit, + // except potentially if it has a special condition or a repeat count. + + BreakCode code = HasHit(); + if ( code ) + g_debugger_state.BreakBeforeNextStmt(true); + + return code; +} + +BreakCode DbgBreakpoint::ShouldBreak(double t) { + if ( kind != BP_TIME ) + reporter->InternalError("Calling ShouldBreak(time) on a non-time breakpoint"); + + if ( t < at_time ) + return BC_NO_HIT; + + if ( ! IsEnabled() ) + return BC_NO_HIT; + + BreakCode code = HasHit(); + if ( code ) + g_debugger_state.BreakBeforeNextStmt(true); + + return code; +} + +void DbgBreakpoint::PrintHitMsg() { + switch ( kind ) { + case BP_STMT: + case BP_FUNC: + case BP_LINE: { + ODesc d; + Frame* f = g_frame_stack.back(); + const ScriptFunc* func = f->GetFunction(); + + if ( func ) + func->DescribeDebug(&d, f->GetFuncArgs()); + + const Location* loc = at_stmt->GetLocationInfo(); + + debug_msg("Breakpoint %d, %s at %s:%d\n", GetID(), d.Description(), loc->filename, loc->first_line); + } + return; + + case BP_TIME: assert(false); + + default: reporter->InternalError("Missed a case in DbgBreakpoint::PrintHitMsg\n"); + } +} + +} // namespace zeek::detail diff --git a/src/DbgBreakpoint.h b/src/DbgBreakpoint.h index f5af59d7a8..bca79f5498 100644 --- a/src/DbgBreakpoint.h +++ b/src/DbgBreakpoint.h @@ -6,95 +6,82 @@ #include "zeek/util.h" -namespace zeek::detail - { +namespace zeek::detail { class Stmt; class ParseLocationRec; -enum BreakCode - { - BC_NO_HIT, - BC_HIT, - BC_HIT_AND_DELETE - }; -class DbgBreakpoint - { - enum Kind - { - BP_STMT = 0, - BP_FUNC, - BP_LINE, - BP_TIME - }; +enum BreakCode { BC_NO_HIT, BC_HIT, BC_HIT_AND_DELETE }; +class DbgBreakpoint { + enum Kind { BP_STMT = 0, BP_FUNC, BP_LINE, BP_TIME }; public: - DbgBreakpoint(); - ~DbgBreakpoint(); + DbgBreakpoint(); + ~DbgBreakpoint(); - int GetID() const { return BPID; } - void SetID(int newID) { BPID = newID; } + int GetID() const { return BPID; } + void SetID(int newID) { BPID = newID; } - // True if breakpoint could be set; false otherwise - bool SetLocation(ParseLocationRec plr, std::string_view loc_str); - bool SetLocation(Stmt* stmt); - bool SetLocation(double time); + // True if breakpoint could be set; false otherwise + bool SetLocation(ParseLocationRec plr, std::string_view loc_str); + bool SetLocation(Stmt* stmt); + bool SetLocation(double time); - bool Reset(); // cancel and re-apply bpt when restarting execution + bool Reset(); // cancel and re-apply bpt when restarting execution - // Temporary = disable (remove?) the breakpoint right after it's hit. - bool IsTemporary() const { return temporary; } - void SetTemporary(bool is_temporary) { temporary = is_temporary; } + // Temporary = disable (remove?) the breakpoint right after it's hit. + bool IsTemporary() const { return temporary; } + void SetTemporary(bool is_temporary) { temporary = is_temporary; } - // Feed it a Stmt* or a time and see if this breakpoint should - // hit. bcHitAndDelete means that it has hit, and should now be - // deleted entirely. - // - // NOTE: If it returns a hit, the DbgBreakpoint object will take - // appropriate action (e.g., resetting counters). - BreakCode ShouldBreak(Stmt* s); - BreakCode ShouldBreak(double t); + // Feed it a Stmt* or a time and see if this breakpoint should + // hit. bcHitAndDelete means that it has hit, and should now be + // deleted entirely. + // + // NOTE: If it returns a hit, the DbgBreakpoint object will take + // appropriate action (e.g., resetting counters). + BreakCode ShouldBreak(Stmt* s); + BreakCode ShouldBreak(double t); - const std::string& GetCondition() const { return condition; } - bool SetCondition(const std::string& new_condition); + const std::string& GetCondition() const { return condition; } + bool SetCondition(const std::string& new_condition); - int GetRepeatCount() const { return repeat_count; } - bool SetRepeatCount(int count); // implements function of ignore command in gdb + int GetRepeatCount() const { return repeat_count; } + bool SetRepeatCount(int count); // implements function of ignore command in gdb - bool IsEnabled() const { return enabled; } - bool SetEnable(bool do_enable); + bool IsEnabled() const { return enabled; } + bool SetEnable(bool do_enable); - // e.g. "FooBar() at foo.c:23" - const char* Description() const { return description; } + // e.g. "FooBar() at foo.c:23" + const char* Description() const { return description; } protected: - void AddToGlobalMap(); - void RemoveFromGlobalMap(); + void AddToGlobalMap(); + void RemoveFromGlobalMap(); - void AddToStmt(); - void RemoveFromStmt(); + void AddToStmt(); + void RemoveFromStmt(); - BreakCode HasHit(); // a breakpoint hit, update state, return proper code. - void PrintHitMsg(); // display reason when the breakpoint hits + BreakCode HasHit(); // a breakpoint hit, update state, return proper code. + void PrintHitMsg(); // display reason when the breakpoint hits - Kind kind; - int32_t BPID; + Kind kind; + int32_t BPID; - char description[512]; - std::string function_name; // location - const char* source_filename; - int32_t source_line; - bool enabled; // ### comment this and next - bool temporary; + char description[512]; + std::string function_name; // location + const char* source_filename; + int32_t source_line; + bool enabled; // ### comment this and next + bool temporary; - Stmt* at_stmt; - double at_time; // break when the virtual time is this + Stmt* at_stmt; + double at_time; // break when the virtual time is this - // Support for conditional and N'th time breakpoints. - int32_t repeat_count; // if positive, break after this many hits - int32_t hit_count; // how many times it's been hit (w/o breaking) + // Support for conditional and N'th time breakpoints. + int32_t repeat_count; // if positive, break after this many hits + int32_t hit_count; // how many times it's been hit (w/o breaking) - std::string condition; // condition to evaluate; nil for none - }; + std::string condition; // condition to evaluate; nil for none +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DbgDisplay.h b/src/DbgDisplay.h index 09af4b230e..8e79a5d736 100644 --- a/src/DbgDisplay.h +++ b/src/DbgDisplay.h @@ -2,30 +2,27 @@ #pragma once -namespace zeek::detail - { +namespace zeek::detail { class Expr; // Automatic displays: display these at each stoppage. -class DbgDisplay - { +class DbgDisplay { public: - DbgDisplay(Expr* expr_to_display); + DbgDisplay(Expr* expr_to_display); - bool IsEnabled() { return enabled; } - bool SetEnable(bool do_enable) - { - bool old_value = enabled; - enabled = do_enable; - return old_value; - } + bool IsEnabled() { return enabled; } + bool SetEnable(bool do_enable) { + bool old_value = enabled; + enabled = do_enable; + return old_value; + } - const Expr* Expression() const { return expression; } + const Expr* Expression() const { return expression; } protected: - bool enabled; - Expr* expression; - }; + bool enabled; + Expr* expression; +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DbgWatch.cc b/src/DbgWatch.cc index a5a9fc3f0d..fd8700bd6e 100644 --- a/src/DbgWatch.cc +++ b/src/DbgWatch.cc @@ -7,18 +7,11 @@ #include "zeek/Debug.h" #include "zeek/Reporter.h" -namespace zeek::detail - { +namespace zeek::detail { // Support classes -DbgWatch::DbgWatch(zeek::Obj* var_to_watch) - { - reporter->InternalError("DbgWatch unimplemented"); - } +DbgWatch::DbgWatch(zeek::Obj* var_to_watch) { reporter->InternalError("DbgWatch unimplemented"); } -DbgWatch::DbgWatch(Expr* expr_to_watch) - { - reporter->InternalError("DbgWatch unimplemented"); - } +DbgWatch::DbgWatch(Expr* expr_to_watch) { reporter->InternalError("DbgWatch unimplemented"); } - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DbgWatch.h b/src/DbgWatch.h index 22715c2195..93c12f6e38 100644 --- a/src/DbgWatch.h +++ b/src/DbgWatch.h @@ -4,26 +4,23 @@ #include "zeek/util.h" -namespace zeek - { +namespace zeek { class Obj; - } +} -namespace zeek::detail - { +namespace zeek::detail { class Expr; -class DbgWatch - { +class DbgWatch { public: - explicit DbgWatch(Obj* var_to_watch); - explicit DbgWatch(Expr* expr_to_watch); - ~DbgWatch() = default; + explicit DbgWatch(Obj* var_to_watch); + explicit DbgWatch(Expr* expr_to_watch); + ~DbgWatch() = default; protected: - Obj* var; - Expr* expr; - }; + Obj* var; + Expr* expr; +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Debug.cc b/src/Debug.cc index 6300443171..3d2fd26273 100644 --- a/src/Debug.cc +++ b/src/Debug.cc @@ -32,10 +32,9 @@ #include "zeek/module_util.h" #include "zeek/util.h" -extern "C" - { +extern "C" { #include "zeek/3rdparty/setsignal.h" - } +} using namespace std; @@ -66,362 +65,319 @@ extern YYLTYPE yylloc; // holds start line and column of token extern int line_number; extern const char* filename; -namespace zeek::detail - { +namespace zeek::detail { -DebuggerState::DebuggerState() - { - next_bp_id = next_watch_id = next_display_id = 1; - BreakBeforeNextStmt(false); - curr_frame_idx = 0; - already_did_list = false; - BreakFromSignal(false); +DebuggerState::DebuggerState() { + next_bp_id = next_watch_id = next_display_id = 1; + BreakBeforeNextStmt(false); + curr_frame_idx = 0; + already_did_list = false; + BreakFromSignal(false); - // ### Don't choose this arbitrary size! Extend Frame. - dbg_locals = new Frame(1024, /* func = */ nullptr, /* fn_args = */ nullptr); - } + // ### Don't choose this arbitrary size! Extend Frame. + dbg_locals = new Frame(1024, /* func = */ nullptr, /* fn_args = */ nullptr); +} -DebuggerState::~DebuggerState() - { - Unref(dbg_locals); - } +DebuggerState::~DebuggerState() { Unref(dbg_locals); } -bool StmtLocMapping::StartsAfter(const StmtLocMapping* m2) - { - if ( ! m2 ) - reporter->InternalError("Assertion failed: m2 != 0"); +bool StmtLocMapping::StartsAfter(const StmtLocMapping* m2) { + if ( ! m2 ) + reporter->InternalError("Assertion failed: m2 != 0"); - return loc.first_line > m2->loc.first_line || - (loc.first_line == m2->loc.first_line && loc.first_column > m2->loc.first_column); - } + return loc.first_line > m2->loc.first_line || + (loc.first_line == m2->loc.first_line && loc.first_column > m2->loc.first_column); +} // Generic debug message output. -int debug_msg(const char* fmt, ...) - { - va_list args; - int retval; +int debug_msg(const char* fmt, ...) { + va_list args; + int retval; - va_start(args, fmt); - retval = vfprintf(stderr, fmt, args); - va_end(args); + va_start(args, fmt); + retval = vfprintf(stderr, fmt, args); + va_end(args); - return retval; - } + return retval; +} // Trace message output -FILE* TraceState::SetTraceFile(const char* trace_filename) - { - FILE* newfile; +FILE* TraceState::SetTraceFile(const char* trace_filename) { + FILE* newfile; - if ( util::streq(trace_filename, "-") ) - newfile = stderr; - else - newfile = fopen(trace_filename, "w"); + if ( util::streq(trace_filename, "-") ) + newfile = stderr; + else + newfile = fopen(trace_filename, "w"); - FILE* oldfile = trace_file; - if ( newfile ) - { - trace_file = newfile; - } - else - { - fprintf(stderr, "Unable to open trace file %s\n", trace_filename); - trace_file = nullptr; - } + FILE* oldfile = trace_file; + if ( newfile ) { + trace_file = newfile; + } + else { + fprintf(stderr, "Unable to open trace file %s\n", trace_filename); + trace_file = nullptr; + } - return oldfile; - } + return oldfile; +} -void TraceState::TraceOn() - { - fprintf(stderr, "Execution tracing ON.\n"); - dbgtrace = true; - } +void TraceState::TraceOn() { + fprintf(stderr, "Execution tracing ON.\n"); + dbgtrace = true; +} -void TraceState::TraceOff() - { - fprintf(stderr, "Execution tracing OFF.\n"); - dbgtrace = false; - } +void TraceState::TraceOff() { + fprintf(stderr, "Execution tracing OFF.\n"); + dbgtrace = false; +} -int TraceState::LogTrace(const char* fmt, ...) - { - va_list args; - int retval; +int TraceState::LogTrace(const char* fmt, ...) { + va_list args; + int retval; - va_start(args, fmt); + va_start(args, fmt); - // Prefix includes timestamp and file/line info. - fprintf(trace_file, "%.6f ", run_state::network_time); + // Prefix includes timestamp and file/line info. + fprintf(trace_file, "%.6f ", run_state::network_time); - const Stmt* stmt; - Location loc; - loc.filename = nullptr; + const Stmt* stmt; + Location loc; + loc.filename = nullptr; - if ( g_frame_stack.size() > 0 && g_frame_stack.back() ) - { - stmt = g_frame_stack.back()->GetNextStmt(); - if ( stmt ) - loc = *stmt->GetLocationInfo(); - else - { - const ScriptFunc* f = g_frame_stack.back()->GetFunction(); - if ( f ) - loc = *f->GetLocationInfo(); - } - } + if ( g_frame_stack.size() > 0 && g_frame_stack.back() ) { + stmt = g_frame_stack.back()->GetNextStmt(); + if ( stmt ) + loc = *stmt->GetLocationInfo(); + else { + const ScriptFunc* f = g_frame_stack.back()->GetFunction(); + if ( f ) + loc = *f->GetLocationInfo(); + } + } - if ( ! loc.filename ) - { - loc.filename = util::copy_string(""); - loc.last_line = 0; - } + if ( ! loc.filename ) { + loc.filename = util::copy_string(""); + loc.last_line = 0; + } - fprintf(trace_file, "%s:%d", loc.filename, loc.last_line); + fprintf(trace_file, "%s:%d", loc.filename, loc.last_line); - // Each stack frame is indented. - for ( int i = 0; i < int(g_frame_stack.size()); ++i ) - fprintf(trace_file, "\t"); + // Each stack frame is indented. + for ( int i = 0; i < int(g_frame_stack.size()); ++i ) + fprintf(trace_file, "\t"); - retval = vfprintf(trace_file, fmt, args); + retval = vfprintf(trace_file, fmt, args); - fflush(trace_file); - va_end(args); + fflush(trace_file); + va_end(args); - return retval; - } + return retval; +} // Helper functions. -void get_first_statement(Stmt* list, Stmt*& first, Location& loc) - { - if ( ! list ) - { - first = nullptr; - return; - } +void get_first_statement(Stmt* list, Stmt*& first, Location& loc) { + if ( ! list ) { + first = nullptr; + return; + } - first = list; - while ( first->Tag() == STMT_LIST ) - { - if ( first->AsStmtList()->Stmts()[0] ) - first = first->AsStmtList()->Stmts()[0].get(); - else - break; - } + first = list; + while ( first->Tag() == STMT_LIST ) { + if ( first->AsStmtList()->Stmts()[0] ) + first = first->AsStmtList()->Stmts()[0].get(); + else + break; + } - loc = *first->GetLocationInfo(); - } + loc = *first->GetLocationInfo(); +} static void parse_function_name(vector& result, ParseLocationRec& plr, - const string& s) - { // function name - const auto& id = lookup_ID(s.c_str(), current_module.c_str()); + const string& s) { // function name + const auto& id = lookup_ID(s.c_str(), current_module.c_str()); - if ( ! id ) - { - string fullname = make_full_var_name(current_module.c_str(), s.c_str()); - debug_msg("Function %s not defined.\n", fullname.c_str()); - plr.type = PLR_UNKNOWN; - return; - } + if ( ! id ) { + string fullname = make_full_var_name(current_module.c_str(), s.c_str()); + debug_msg("Function %s not defined.\n", fullname.c_str()); + plr.type = PLR_UNKNOWN; + return; + } - if ( ! id->GetType()->AsFuncType() ) - { - debug_msg("Function %s not declared.\n", id->Name()); - plr.type = PLR_UNKNOWN; - return; - } + if ( ! id->GetType()->AsFuncType() ) { + debug_msg("Function %s not declared.\n", id->Name()); + plr.type = PLR_UNKNOWN; + return; + } - if ( ! id->HasVal() ) - { - debug_msg("Function %s declared but not defined.\n", id->Name()); - plr.type = PLR_UNKNOWN; - return; - } + if ( ! id->HasVal() ) { + debug_msg("Function %s declared but not defined.\n", id->Name()); + plr.type = PLR_UNKNOWN; + return; + } - const Func* func = id->GetVal()->AsFunc(); - const vector& bodies = func->GetBodies(); + const Func* func = id->GetVal()->AsFunc(); + const vector& bodies = func->GetBodies(); - if ( bodies.size() == 0 ) - { - debug_msg("Function %s is a built-in function\n", id->Name()); - plr.type = PLR_UNKNOWN; - return; - } + if ( bodies.size() == 0 ) { + debug_msg("Function %s is a built-in function\n", id->Name()); + plr.type = PLR_UNKNOWN; + return; + } - Stmt* body = nullptr; // the particular body we care about; 0 = all + Stmt* body = nullptr; // the particular body we care about; 0 = all - if ( bodies.size() == 1 ) - body = bodies[0].stmts.get(); - else - { - while ( true ) - { - debug_msg("There are multiple definitions of that event handler.\n" - "Please choose one of the following options:\n"); - for ( unsigned int i = 0; i < bodies.size(); ++i ) - { - Stmt* first; - Location stmt_loc; - get_first_statement(bodies[i].stmts.get(), first, stmt_loc); - debug_msg("[%d] %s:%d\n", i + 1, stmt_loc.filename, stmt_loc.first_line); - } + if ( bodies.size() == 1 ) + body = bodies[0].stmts.get(); + else { + while ( true ) { + debug_msg( + "There are multiple definitions of that event handler.\n" + "Please choose one of the following options:\n"); + for ( unsigned int i = 0; i < bodies.size(); ++i ) { + Stmt* first; + Location stmt_loc; + get_first_statement(bodies[i].stmts.get(), first, stmt_loc); + debug_msg("[%d] %s:%d\n", i + 1, stmt_loc.filename, stmt_loc.first_line); + } - debug_msg("[a] All of the above\n"); - debug_msg("[n] None of the above\n"); - debug_msg("Enter your choice: "); + debug_msg("[a] All of the above\n"); + debug_msg("[n] None of the above\n"); + debug_msg("Enter your choice: "); - char charinput[256]; - if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) - { - plr.type = PLR_UNKNOWN; - return; - } + char charinput[256]; + if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) { + plr.type = PLR_UNKNOWN; + return; + } - if ( charinput[strlen(charinput) - 1] == '\n' ) - charinput[strlen(charinput) - 1] = 0; + if ( charinput[strlen(charinput) - 1] == '\n' ) + charinput[strlen(charinput) - 1] = 0; - string input = charinput; + string input = charinput; - if ( input == "a" ) - break; + if ( input == "a" ) + break; - if ( input == "n" ) - { - plr.type = PLR_UNKNOWN; - return; - } + if ( input == "n" ) { + plr.type = PLR_UNKNOWN; + return; + } - int option = atoi(input.c_str()); - if ( option > 0 && option <= (int)bodies.size() ) - { - body = bodies[option - 1].stmts.get(); - break; - } - } - } + int option = atoi(input.c_str()); + if ( option > 0 && option <= (int)bodies.size() ) { + body = bodies[option - 1].stmts.get(); + break; + } + } + } - plr.type = PLR_FUNCTION; + plr.type = PLR_FUNCTION; - // Find first atomic (non-STMT_LIST) statement - Stmt* first; - Location stmt_loc; + // Find first atomic (non-STMT_LIST) statement + Stmt* first; + Location stmt_loc; - if ( body ) - { - get_first_statement(body, first, stmt_loc); - if ( first ) - { - plr.stmt = first; - plr.filename = stmt_loc.filename; - plr.line = stmt_loc.last_line; - } - } + if ( body ) { + get_first_statement(body, first, stmt_loc); + if ( first ) { + plr.stmt = first; + plr.filename = stmt_loc.filename; + plr.line = stmt_loc.last_line; + } + } - else - { - result.pop_back(); - ParseLocationRec result_plr; + else { + result.pop_back(); + ParseLocationRec result_plr; - for ( const auto& body : bodies ) - { - get_first_statement(body.stmts.get(), first, stmt_loc); - if ( ! first ) - continue; + for ( const auto& body : bodies ) { + get_first_statement(body.stmts.get(), first, stmt_loc); + if ( ! first ) + continue; - result_plr.type = PLR_FUNCTION; - result_plr.stmt = first; - result_plr.filename = stmt_loc.filename; - result_plr.line = stmt_loc.last_line; - result.push_back(result_plr); - } - } - } + result_plr.type = PLR_FUNCTION; + result_plr.stmt = first; + result_plr.filename = stmt_loc.filename; + result_plr.line = stmt_loc.last_line; + result.push_back(result_plr); + } + } +} -vector parse_location_string(const string& s) - { - vector result; - result.emplace_back(); - ParseLocationRec& plr = result[0]; +vector parse_location_string(const string& s) { + vector result; + result.emplace_back(); + ParseLocationRec& plr = result[0]; - // If PLR_FILE_AND_LINE, set this to the filename you want; for - // memory management reasons, the real filename is set when looking - // up the line number to find the corresponding statement. - std::string loc_filename; + // If PLR_FILE_AND_LINE, set this to the filename you want; for + // memory management reasons, the real filename is set when looking + // up the line number to find the corresponding statement. + std::string loc_filename; - if ( sscanf(s.c_str(), "%d", &plr.line) ) - { // just a line number (implicitly referring to the current file) - loc_filename = g_debugger_state.last_loc.filename; - plr.type = PLR_FILE_AND_LINE; - } + if ( sscanf(s.c_str(), "%d", &plr.line) ) { // just a line number (implicitly referring to the current file) + loc_filename = g_debugger_state.last_loc.filename; + plr.type = PLR_FILE_AND_LINE; + } - else - { - string::size_type pos_colon = s.find(':'); - string::size_type pos_dblcolon = s.find("::"); + else { + string::size_type pos_colon = s.find(':'); + string::size_type pos_dblcolon = s.find("::"); - if ( pos_colon == string::npos || pos_dblcolon != string::npos ) - parse_function_name(result, plr, s); - else - { // file:line - string policy_filename = s.substr(0, pos_colon); - string line_string = s.substr(pos_colon + 1, s.length() - pos_colon); + if ( pos_colon == string::npos || pos_dblcolon != string::npos ) + parse_function_name(result, plr, s); + else { // file:line + string policy_filename = s.substr(0, pos_colon); + string line_string = s.substr(pos_colon + 1, s.length() - pos_colon); - if ( ! sscanf(line_string.c_str(), "%d", &plr.line) ) - plr.type = PLR_UNKNOWN; + if ( ! sscanf(line_string.c_str(), "%d", &plr.line) ) + plr.type = PLR_UNKNOWN; - string path(util::find_script_file(policy_filename, util::zeek_path())); + string path(util::find_script_file(policy_filename, util::zeek_path())); - if ( path.empty() ) - { - debug_msg("No such policy file: %s.\n", policy_filename.c_str()); - plr.type = PLR_UNKNOWN; - return result; - } + if ( path.empty() ) { + debug_msg("No such policy file: %s.\n", policy_filename.c_str()); + plr.type = PLR_UNKNOWN; + return result; + } - loc_filename = path; - plr.type = PLR_FILE_AND_LINE; - } - } + loc_filename = path; + plr.type = PLR_FILE_AND_LINE; + } + } - if ( plr.type == PLR_FILE_AND_LINE ) - { - auto iter = g_dbgfilemaps.find(loc_filename); - if ( iter == g_dbgfilemaps.end() ) - reporter->InternalError("Policy file %s should have been loaded\n", - loc_filename.data()); + if ( plr.type == PLR_FILE_AND_LINE ) { + auto iter = g_dbgfilemaps.find(loc_filename); + if ( iter == g_dbgfilemaps.end() ) + reporter->InternalError("Policy file %s should have been loaded\n", loc_filename.data()); - if ( plr.line > how_many_lines_in(loc_filename.data()) ) - { - debug_msg("No line %d in %s.\n", plr.line, loc_filename.data()); - plr.type = PLR_UNKNOWN; - return result; - } + if ( plr.line > how_many_lines_in(loc_filename.data()) ) { + debug_msg("No line %d in %s.\n", plr.line, loc_filename.data()); + plr.type = PLR_UNKNOWN; + return result; + } - StmtLocMapping* hit = nullptr; - for ( const auto entry : *(iter->second) ) - { - plr.filename = entry->Loc().filename; + StmtLocMapping* hit = nullptr; + for ( const auto entry : *(iter->second) ) { + plr.filename = entry->Loc().filename; - if ( entry->Loc().first_line > plr.line ) - break; + if ( entry->Loc().first_line > plr.line ) + break; - if ( plr.line >= entry->Loc().first_line && plr.line <= entry->Loc().last_line ) - { - hit = entry; - break; - } - } + if ( plr.line >= entry->Loc().first_line && plr.line <= entry->Loc().last_line ) { + hit = entry; + break; + } + } - if ( hit ) - plr.stmt = hit->Statement(); - else - plr.stmt = nullptr; - } + if ( hit ) + plr.stmt = hit->Statement(); + else + plr.stmt = nullptr; + } - return result; - } + return result; +} // Interactive debugging console. @@ -431,54 +387,50 @@ static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector& args); void using_history(void); -static bool init_readline() - { - // ### Set up custom completion. +static bool init_readline() { + // ### Set up custom completion. - rl_outstream = stderr; - using_history(); + rl_outstream = stderr; + using_history(); - return false; - } + return false; +} #endif -void break_signal(int) - { - g_debugger_state.BreakBeforeNextStmt(true); - g_debugger_state.BreakFromSignal(true); - } +void break_signal(int) { + g_debugger_state.BreakBeforeNextStmt(true); + g_debugger_state.BreakFromSignal(true); +} -int dbg_init_debugger(const char* cmdfile) - { - if ( ! g_policy_debug ) - return 0; // probably shouldn't have been called +int dbg_init_debugger(const char* cmdfile) { + if ( ! g_policy_debug ) + return 0; // probably shouldn't have been called - init_global_dbg_constants(); + init_global_dbg_constants(); - // Hit the debugger before running anything. - g_debugger_state.BreakBeforeNextStmt(true); + // Hit the debugger before running anything. + g_debugger_state.BreakBeforeNextStmt(true); - if ( cmdfile ) - // ### Implement this - debug_msg("Command files not supported. Using interactive mode.\n"); + if ( cmdfile ) + // ### Implement this + debug_msg("Command files not supported. Using interactive mode.\n"); - // ### if ( interactive ) (i.e., not reading cmds from a file) + // ### if ( interactive ) (i.e., not reading cmds from a file) #ifdef HAVE_READLINE - init_readline(); + init_readline(); #endif - setsignal(SIGINT, break_signal); - setsignal(SIGTERM, break_signal); + setsignal(SIGINT, break_signal); + setsignal(SIGTERM, break_signal); - return 1; - } + return 1; +} -int dbg_shutdown_debugger() - { - // ### TODO: Remove signal handlers - return 1; - } +int dbg_shutdown_debugger() { + // ### TODO: Remove signal handlers + return 1; +} // Umesh: I stole this code from libedit; I modified it here to use // s to avoid memory management problems. The main command is returned @@ -488,517 +440,465 @@ int dbg_shutdown_debugger() // Parse the string into individual tokens, similarly to how shell // would do it. -void tokenize(const char* cstr, string& operation, vector& arguments) - { - int num_tokens = 0; - char delim = '\0'; - const string str(cstr); +void tokenize(const char* cstr, string& operation, vector& arguments) { + int num_tokens = 0; + char delim = '\0'; + const string str(cstr); - for ( int i = 0; i < (signed int)str.length(); ++i ) - { - while ( isspace((unsigned char)str[i]) ) - ++i; + for ( int i = 0; i < (signed int)str.length(); ++i ) { + while ( isspace((unsigned char)str[i]) ) + ++i; - int start = i; + int start = i; - for ( ; str[i]; ++i ) - { - if ( str[i] == '\\' ) - { - if ( i < (signed int)str.length() ) - ++i; - } + for ( ; str[i]; ++i ) { + if ( str[i] == '\\' ) { + if ( i < (signed int)str.length() ) + ++i; + } - else if ( ! delim && str[i] == '(' ) - delim = ')'; + else if ( ! delim && str[i] == '(' ) + delim = ')'; - else if ( ! delim && (str[i] == '\'' || str[i] == '"') ) - delim = str[i]; + else if ( ! delim && (str[i] == '\'' || str[i] == '"') ) + delim = str[i]; - else if ( delim && str[i] == delim ) - { - delim = '\0'; - ++i; - break; - } + else if ( delim && str[i] == delim ) { + delim = '\0'; + ++i; + break; + } - else if ( ! delim && isspace(str[i]) ) - break; - } + else if ( ! delim && isspace(str[i]) ) + break; + } - size_t len = i - start; + size_t len = i - start; - if ( ! num_tokens ) - operation = string(str, start, len); - else - arguments.emplace_back(str, start, len); + if ( ! num_tokens ) + operation = string(str, start, len); + else + arguments.emplace_back(str, start, len); - ++num_tokens; - } - } + ++num_tokens; + } +} // Given a command string, parse it and send the command to be dispatched. -int dbg_execute_command(const char* cmd) - { - bool matched_history = false; +int dbg_execute_command(const char* cmd) { + bool matched_history = false; - if ( ! cmd ) - return 0; + if ( ! cmd ) + return 0; - if ( util::streq(cmd, "") ) // do the GDB command completion - { + if ( util::streq(cmd, "") ) // do the GDB command completion + { #ifdef HAVE_READLINE - int i; - for ( i = history_length; i >= 1; --i ) - { - HIST_ENTRY* entry = history_get(i); - if ( ! entry ) - return 0; + int i; + for ( i = history_length; i >= 1; --i ) { + HIST_ENTRY* entry = history_get(i); + if ( ! entry ) + return 0; - const DebugCmdInfo* info = (const DebugCmdInfo*)entry->data; + const DebugCmdInfo* info = (const DebugCmdInfo*)entry->data; - if ( info && info->Repeatable() ) - { - cmd = entry->line; - matched_history = true; - break; - } - } + if ( info && info->Repeatable() ) { + cmd = entry->line; + matched_history = true; + break; + } + } #endif - if ( ! matched_history ) - return 0; - } + if ( ! matched_history ) + return 0; + } - char* localcmd = util::copy_string(cmd); + char* localcmd = util::copy_string(cmd); - string opstring; - vector arguments; - tokenize(localcmd, opstring, arguments); + string opstring; + vector arguments; + tokenize(localcmd, opstring, arguments); - delete[] localcmd; + delete[] localcmd; - // Make sure we know this op name. - auto matching_cmds_buf = std::make_unique(num_debug_cmds()); - auto matching_cmds = matching_cmds_buf.get(); - int num_matches = find_all_matching_cmds(opstring, matching_cmds); + // Make sure we know this op name. + auto matching_cmds_buf = std::make_unique(num_debug_cmds()); + auto matching_cmds = matching_cmds_buf.get(); + int num_matches = find_all_matching_cmds(opstring, matching_cmds); - if ( ! num_matches ) - { - debug_msg("No Matching command for '%s'.\n", opstring.c_str()); - return 0; - } + if ( ! num_matches ) { + debug_msg("No Matching command for '%s'.\n", opstring.c_str()); + return 0; + } - if ( num_matches > 1 ) - { - debug_msg("Ambiguous command; could be\n"); + if ( num_matches > 1 ) { + debug_msg("Ambiguous command; could be\n"); - for ( int i = 0; i < num_debug_cmds(); ++i ) - if ( matching_cmds[i] ) - debug_msg("\t%s\n", matching_cmds[i]); + for ( int i = 0; i < num_debug_cmds(); ++i ) + if ( matching_cmds[i] ) + debug_msg("\t%s\n", matching_cmds[i]); - return 0; - } + return 0; + } - // Matched exactly one command: find out which one. - DebugCmd cmd_code = dcInvalid; - for ( int i = 0; i < num_debug_cmds(); ++i ) - if ( matching_cmds[i] ) - { - cmd_code = (DebugCmd)i; - break; - } + // Matched exactly one command: find out which one. + DebugCmd cmd_code = dcInvalid; + for ( int i = 0; i < num_debug_cmds(); ++i ) + if ( matching_cmds[i] ) { + cmd_code = (DebugCmd)i; + break; + } #ifdef HAVE_READLINE - // Insert command into history. - if ( ! matched_history && cmd && *cmd ) - { - /* The prototype for add_history(), at least under MacOS, - * has it taking a char* rather than a const char*. - * But documentation at - * http://tiswww.case.edu/php/chet/readline/history.html - * suggests that it's safe to assume it's really const char*. - */ - add_history((char*)cmd); - HISTORY_STATE* state = history_get_history_state(); - state->entries[state->length - 1]->data = (histdata_t*)get_debug_cmd_info(cmd_code); - } + // Insert command into history. + if ( ! matched_history && cmd && *cmd ) { + /* The prototype for add_history(), at least under MacOS, + * has it taking a char* rather than a const char*. + * But documentation at + * http://tiswww.case.edu/php/chet/readline/history.html + * suggests that it's safe to assume it's really const char*. + */ + add_history((char*)cmd); + HISTORY_STATE* state = history_get_history_state(); + state->entries[state->length - 1]->data = (histdata_t*)get_debug_cmd_info(cmd_code); + } #endif - if ( int(cmd_code) >= num_debug_cmds() ) - reporter->InternalError("Assertion failed: int(cmd_code) < num_debug_cmds()"); + if ( int(cmd_code) >= num_debug_cmds() ) + reporter->InternalError("Assertion failed: int(cmd_code) < num_debug_cmds()"); - // Dispatch to the op-specific handler (with args). - int retcode = dbg_dispatch_cmd(cmd_code, arguments); - if ( retcode < 0 ) - return retcode; + // Dispatch to the op-specific handler (with args). + int retcode = dbg_dispatch_cmd(cmd_code, arguments); + if ( retcode < 0 ) + return retcode; - const DebugCmdInfo* info = get_debug_cmd_info(cmd_code); - if ( ! info ) - reporter->InternalError("Assertion failed: info"); + const DebugCmdInfo* info = get_debug_cmd_info(cmd_code); + if ( ! info ) + reporter->InternalError("Assertion failed: info"); - if ( ! info ) - return -2; // ### yuck, why -2? + if ( ! info ) + return -2; // ### yuck, why -2? - return info->ResumeExecution(); - } + return info->ResumeExecution(); +} // Call the appropriate function for the command. -static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector& args) - { - switch ( cmd_code ) - { - case dcHelp: - dbg_cmd_help(cmd_code, args); - break; +static int dbg_dispatch_cmd(DebugCmd cmd_code, const vector& args) { + switch ( cmd_code ) { + case dcHelp: dbg_cmd_help(cmd_code, args); break; - case dcQuit: - debug_msg("Program Terminating\n"); - exit(0); + case dcQuit: debug_msg("Program Terminating\n"); exit(0); - case dcNext: - g_frame_stack.back()->BreakBeforeNextStmt(true); - step_or_next_pending = true; - last_frame = g_frame_stack.back(); - break; + case dcNext: + g_frame_stack.back()->BreakBeforeNextStmt(true); + step_or_next_pending = true; + last_frame = g_frame_stack.back(); + break; - case dcStep: - g_debugger_state.BreakBeforeNextStmt(true); - step_or_next_pending = true; - last_frame = g_frame_stack.back(); - break; + case dcStep: + g_debugger_state.BreakBeforeNextStmt(true); + step_or_next_pending = true; + last_frame = g_frame_stack.back(); + break; - case dcContinue: - g_debugger_state.BreakBeforeNextStmt(false); - debug_msg("Continuing.\n"); - break; + case dcContinue: + g_debugger_state.BreakBeforeNextStmt(false); + debug_msg("Continuing.\n"); + break; - case dcFinish: - g_frame_stack.back()->BreakOnReturn(true); - g_debugger_state.BreakBeforeNextStmt(false); - break; + case dcFinish: + g_frame_stack.back()->BreakOnReturn(true); + g_debugger_state.BreakBeforeNextStmt(false); + break; - case dcBreak: - dbg_cmd_break(cmd_code, args); - break; + case dcBreak: dbg_cmd_break(cmd_code, args); break; - case dcBreakCondition: - dbg_cmd_break_condition(cmd_code, args); - break; + case dcBreakCondition: dbg_cmd_break_condition(cmd_code, args); break; - case dcDeleteBreak: - case dcClearBreak: - case dcDisableBreak: - case dcEnableBreak: - case dcIgnoreBreak: - dbg_cmd_break_set_state(cmd_code, args); - break; + case dcDeleteBreak: + case dcClearBreak: + case dcDisableBreak: + case dcEnableBreak: + case dcIgnoreBreak: dbg_cmd_break_set_state(cmd_code, args); break; - case dcPrint: - dbg_cmd_print(cmd_code, args); - break; + case dcPrint: dbg_cmd_print(cmd_code, args); break; - case dcBacktrace: - return dbg_cmd_backtrace(cmd_code, args); + case dcBacktrace: return dbg_cmd_backtrace(cmd_code, args); - case dcFrame: - case dcUp: - case dcDown: - return dbg_cmd_frame(cmd_code, args); + case dcFrame: + case dcUp: + case dcDown: return dbg_cmd_frame(cmd_code, args); - case dcInfo: - return dbg_cmd_info(cmd_code, args); + case dcInfo: return dbg_cmd_info(cmd_code, args); - case dcList: - return dbg_cmd_list(cmd_code, args); + case dcList: return dbg_cmd_list(cmd_code, args); - case dcDisplay: - case dcUndisplay: - debug_msg("Command not yet implemented.\n"); - break; + case dcDisplay: + case dcUndisplay: debug_msg("Command not yet implemented.\n"); break; - case dcTrace: - return dbg_cmd_trace(cmd_code, args); + case dcTrace: return dbg_cmd_trace(cmd_code, args); - default: - debug_msg("INTERNAL ERROR: " - "Got an unknown debugger command in DbgDispatchCmd: %d\n", - cmd_code); - return 0; - } + default: + debug_msg( + "INTERNAL ERROR: " + "Got an unknown debugger command in DbgDispatchCmd: %d\n", + cmd_code); + return 0; + } - return 0; - } + return 0; +} -static char* get_prompt(bool reset_counter = false) - { - static char prompt[512]; - static int counter = 0; +static char* get_prompt(bool reset_counter = false) { + static char prompt[512]; + static int counter = 0; - if ( reset_counter ) - counter = 0; + if ( reset_counter ) + counter = 0; - snprintf(prompt, sizeof(prompt), "(Zeek [%d]) ", counter++); + snprintf(prompt, sizeof(prompt), "(Zeek [%d]) ", counter++); - return prompt; - } + return prompt; +} -string get_context_description(const Stmt* stmt, const Frame* frame) - { - ODesc d; - const ScriptFunc* func = frame ? frame->GetFunction() : nullptr; +string get_context_description(const Stmt* stmt, const Frame* frame) { + ODesc d; + const ScriptFunc* func = frame ? frame->GetFunction() : nullptr; - if ( func ) - func->DescribeDebug(&d, frame->GetFuncArgs()); - else - d.Add("", 0); + if ( func ) + func->DescribeDebug(&d, frame->GetFuncArgs()); + else + d.Add("", 0); - Location loc; - if ( stmt ) - loc = *stmt->GetLocationInfo(); - else - { - loc.filename = util::copy_string(""); - loc.last_line = 0; - } + Location loc; + if ( stmt ) + loc = *stmt->GetLocationInfo(); + else { + loc.filename = util::copy_string(""); + loc.last_line = 0; + } - size_t buf_size = strlen(d.Description()) + strlen(loc.filename) + 1024; - char* buf = new char[buf_size]; - snprintf(buf, buf_size, "In %s at %s:%d", d.Description(), loc.filename, loc.last_line); + size_t buf_size = strlen(d.Description()) + strlen(loc.filename) + 1024; + char* buf = new char[buf_size]; + snprintf(buf, buf_size, "In %s at %s:%d", d.Description(), loc.filename, loc.last_line); - string retval(buf); - delete[] buf; - return retval; - } + string retval(buf); + delete[] buf; + return retval; +} -int dbg_handle_debug_input() - { - static char* input_line = nullptr; - int status = 0; +int dbg_handle_debug_input() { + static char* input_line = nullptr; + int status = 0; - if ( g_debugger_state.BreakFromSignal() ) - { - debug_msg("Program received signal SIGINT: entering debugger\n"); + if ( g_debugger_state.BreakFromSignal() ) { + debug_msg("Program received signal SIGINT: entering debugger\n"); - g_debugger_state.BreakFromSignal(false); - } + g_debugger_state.BreakFromSignal(false); + } - Frame* curr_frame = g_frame_stack.back(); - const ScriptFunc* func = curr_frame->GetFunction(); - if ( func ) - current_module = extract_module_name(func->Name()); - else - current_module = GLOBAL_MODULE_NAME; + Frame* curr_frame = g_frame_stack.back(); + const ScriptFunc* func = curr_frame->GetFunction(); + if ( func ) + current_module = extract_module_name(func->Name()); + else + current_module = GLOBAL_MODULE_NAME; - const Stmt* stmt = curr_frame->GetNextStmt(); - if ( ! stmt ) - reporter->InternalError("Assertion failed: stmt != 0"); + const Stmt* stmt = curr_frame->GetNextStmt(); + if ( ! stmt ) + reporter->InternalError("Assertion failed: stmt != 0"); - const Location loc = *stmt->GetLocationInfo(); + const Location loc = *stmt->GetLocationInfo(); - if ( ! step_or_next_pending || g_frame_stack.back() != last_frame ) - { - string context = get_context_description(stmt, g_frame_stack.back()); - debug_msg("%s\n", context.c_str()); - } + if ( ! step_or_next_pending || g_frame_stack.back() != last_frame ) { + string context = get_context_description(stmt, g_frame_stack.back()); + debug_msg("%s\n", context.c_str()); + } - step_or_next_pending = false; + step_or_next_pending = false; - PrintLines(loc.filename, loc.first_line, loc.last_line - loc.first_line + 1, true); - g_debugger_state.last_loc = loc; + PrintLines(loc.filename, loc.first_line, loc.last_line - loc.first_line + 1, true); + g_debugger_state.last_loc = loc; - do - { - // readline returns a pointer to a buffer it allocates; it's - // freed at the bottom. + do { + // readline returns a pointer to a buffer it allocates; it's + // freed at the bottom. #ifdef HAVE_READLINE - input_line = readline(get_prompt()); + input_line = readline(get_prompt()); #else - printf("%s", get_prompt()); + printf("%s", get_prompt()); - // readline uses malloc, and we want to be consistent - // with it. - input_line = (char*)util::safe_malloc(1024); - input_line[1023] = 0; - // ### Maybe it's not always stdin. - input_line = fgets(input_line, 1023, stdin); + // readline uses malloc, and we want to be consistent + // with it. + input_line = (char*)util::safe_malloc(1024); + input_line[1023] = 0; + // ### Maybe it's not always stdin. + input_line = fgets(input_line, 1023, stdin); #endif - // ### Maybe not stdin; maybe do better cleanup. - if ( feof(stdin) ) - exit(0); + // ### Maybe not stdin; maybe do better cleanup. + if ( feof(stdin) ) + exit(0); - status = dbg_execute_command(input_line); + status = dbg_execute_command(input_line); - if ( input_line ) - { - free(input_line); // this was malloc'ed - input_line = nullptr; - } - else - exit(0); - } while ( status == 0 ); + if ( input_line ) { + free(input_line); // this was malloc'ed + input_line = nullptr; + } + else + exit(0); + } while ( status == 0 ); - // Clear out some state. ### Is there a better place? - g_debugger_state.curr_frame_idx = 0; - g_debugger_state.already_did_list = false; + // Clear out some state. ### Is there a better place? + g_debugger_state.curr_frame_idx = 0; + g_debugger_state.already_did_list = false; - setsignal(SIGINT, break_signal); - setsignal(SIGTERM, break_signal); + setsignal(SIGINT, break_signal); + setsignal(SIGTERM, break_signal); - return 0; - } + return 0; +} // Return true to continue execution, false to abort. -bool pre_execute_stmt(Stmt* stmt, Frame* f) - { - if ( ! g_policy_debug || stmt->Tag() == STMT_LIST || stmt->Tag() == STMT_NULL ) - return true; +bool pre_execute_stmt(Stmt* stmt, Frame* f) { + if ( ! g_policy_debug || stmt->Tag() == STMT_LIST || stmt->Tag() == STMT_NULL ) + return true; - if ( g_trace_state.DoTrace() ) - { - ODesc d; - stmt->Describe(&d); + if ( g_trace_state.DoTrace() ) { + ODesc d; + stmt->Describe(&d); - const char* desc = d.Description(); - const char* s = strchr(desc, '\n'); + const char* desc = d.Description(); + const char* s = strchr(desc, '\n'); - int len; - if ( s ) - len = s - desc; - else - len = strlen(desc); + int len; + if ( s ) + len = s - desc; + else + len = strlen(desc); - g_trace_state.LogTrace("%*s\n", len, desc); - } + g_trace_state.LogTrace("%*s\n", len, desc); + } - bool should_break = false; + bool should_break = false; - if ( g_debugger_state.BreakBeforeNextStmt() || f->BreakBeforeNextStmt() ) - { - if ( g_debugger_state.BreakBeforeNextStmt() ) - g_debugger_state.BreakBeforeNextStmt(false); + if ( g_debugger_state.BreakBeforeNextStmt() || f->BreakBeforeNextStmt() ) { + if ( g_debugger_state.BreakBeforeNextStmt() ) + g_debugger_state.BreakBeforeNextStmt(false); - if ( f->BreakBeforeNextStmt() ) - f->BreakBeforeNextStmt(false); + if ( f->BreakBeforeNextStmt() ) + f->BreakBeforeNextStmt(false); - should_break = true; - } + should_break = true; + } - if ( stmt->BPCount() ) - { - pair p; + if ( stmt->BPCount() ) { + pair p; - p = g_debugger_state.breakpoint_map.equal_range(stmt); + p = g_debugger_state.breakpoint_map.equal_range(stmt); - if ( p.first == p.second ) - reporter->InternalError("Breakpoint count nonzero, but no matching breakpoints"); + if ( p.first == p.second ) + reporter->InternalError("Breakpoint count nonzero, but no matching breakpoints"); - for ( BPMapType::iterator i = p.first; i != p.second; ++i ) - { - int break_code = i->second->ShouldBreak(stmt); - if ( break_code == 2 ) // ### 2? - { - i->second->SetEnable(false); - delete i->second; - } + for ( BPMapType::iterator i = p.first; i != p.second; ++i ) { + int break_code = i->second->ShouldBreak(stmt); + if ( break_code == 2 ) // ### 2? + { + i->second->SetEnable(false); + delete i->second; + } - should_break = should_break || break_code; - } - } + should_break = should_break || break_code; + } + } - if ( should_break ) - dbg_handle_debug_input(); + if ( should_break ) + dbg_handle_debug_input(); - return true; - } + return true; +} -bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow) - { - // Handle the case where someone issues a "next" debugger command, - // but we're at a return statement, so the next statement is in - // some other function. - if ( *flow == FLOW_RETURN && f->BreakBeforeNextStmt() ) - g_debugger_state.BreakBeforeNextStmt(true); +bool post_execute_stmt(Stmt* stmt, Frame* f, Val* result, StmtFlowType* flow) { + // Handle the case where someone issues a "next" debugger command, + // but we're at a return statement, so the next statement is in + // some other function. + if ( *flow == FLOW_RETURN && f->BreakBeforeNextStmt() ) + g_debugger_state.BreakBeforeNextStmt(true); - // Handle "finish" commands. - if ( *flow == FLOW_RETURN && f->BreakOnReturn() ) - { - if ( result ) - { - ODesc d; - result->Describe(&d); - debug_msg("Return Value: '%s'\n", d.Description()); - } - else - debug_msg("Return Value: \n"); + // Handle "finish" commands. + if ( *flow == FLOW_RETURN && f->BreakOnReturn() ) { + if ( result ) { + ODesc d; + result->Describe(&d); + debug_msg("Return Value: '%s'\n", d.Description()); + } + else + debug_msg("Return Value: \n"); - g_debugger_state.BreakBeforeNextStmt(true); - f->BreakOnReturn(false); - } + g_debugger_state.BreakBeforeNextStmt(true); + f->BreakOnReturn(false); + } - return true; - } + return true; +} -ValPtr dbg_eval_expr(const char* expr) - { - // Push the current frame's associated scope. - // Note: g_debugger_state.curr_frame_idx is the user-visible number, - // while the array index goes in the opposite direction - int frame_idx = (g_frame_stack.size() - 1) - g_debugger_state.curr_frame_idx; +ValPtr dbg_eval_expr(const char* expr) { + // Push the current frame's associated scope. + // Note: g_debugger_state.curr_frame_idx is the user-visible number, + // while the array index goes in the opposite direction + int frame_idx = (g_frame_stack.size() - 1) - g_debugger_state.curr_frame_idx; - if ( ! (frame_idx >= 0 && (unsigned)frame_idx < g_frame_stack.size()) ) - reporter->InternalError( - "Assertion failed: frame_idx >= 0 && (unsigned) frame_idx < g_frame_stack.size()"); + if ( ! (frame_idx >= 0 && (unsigned)frame_idx < g_frame_stack.size()) ) + reporter->InternalError("Assertion failed: frame_idx >= 0 && (unsigned) frame_idx < g_frame_stack.size()"); - Frame* frame = g_frame_stack[frame_idx]; - if ( ! (frame) ) - reporter->InternalError("Assertion failed: frame"); + Frame* frame = g_frame_stack[frame_idx]; + if ( ! (frame) ) + reporter->InternalError("Assertion failed: frame"); - const ScriptFunc* func = frame->GetFunction(); - if ( func ) - push_existing_scope(func->GetScope()); + const ScriptFunc* func = frame->GetFunction(); + if ( func ) + push_existing_scope(func->GetScope()); - // ### Possibly push a debugger-local scope? + // ### Possibly push a debugger-local scope? - // Set up the lexer to read from the string. - string parse_string = string("@DEBUG ") + expr; - zeek_scan_string(parse_string.c_str()); + // Set up the lexer to read from the string. + string parse_string = string("@DEBUG ") + expr; + zeek_scan_string(parse_string.c_str()); - // Fix filename and line number for the lexer/parser, which record it. - filename = ""; - line_number = 1; - yylloc.filename = filename; - yylloc.first_line = yylloc.last_line = line_number = 1; + // Fix filename and line number for the lexer/parser, which record it. + filename = ""; + line_number = 1; + yylloc.filename = filename; + yylloc.first_line = yylloc.last_line = line_number = 1; - // Parse the thing into an expr. - ValPtr result; - if ( yyparse() ) - { - if ( g_curr_debug_error ) - debug_msg("Parsing expression '%s' failed: %s\n", expr, g_curr_debug_error); - else - debug_msg("Parsing expression '%s' failed\n", expr); + // Parse the thing into an expr. + ValPtr result; + if ( yyparse() ) { + if ( g_curr_debug_error ) + debug_msg("Parsing expression '%s' failed: %s\n", expr, g_curr_debug_error); + else + debug_msg("Parsing expression '%s' failed\n", expr); - if ( g_curr_debug_expr ) - { - delete g_curr_debug_expr; - g_curr_debug_expr = nullptr; - } - } - else - result = g_curr_debug_expr->Eval(frame); + if ( g_curr_debug_expr ) { + delete g_curr_debug_expr; + g_curr_debug_expr = nullptr; + } + } + else + result = g_curr_debug_expr->Eval(frame); - if ( func ) - pop_scope(); + if ( func ) + pop_scope(); - delete g_curr_debug_expr; - g_curr_debug_expr = nullptr; - delete[] g_curr_debug_error; - g_curr_debug_error = nullptr; - in_debug = false; + delete g_curr_debug_expr; + g_curr_debug_expr = nullptr; + delete[] g_curr_debug_error; + g_curr_debug_error = nullptr; + in_debug = false; - return result; - } + return result; +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Debug.h b/src/Debug.h index a2e7da266d..0a818866d5 100644 --- a/src/Debug.h +++ b/src/Debug.h @@ -11,17 +11,16 @@ #include "zeek/StmtEnums.h" #include "zeek/util.h" -namespace zeek - { +namespace zeek { class Val; -template class IntrusivePtr; +template +class IntrusivePtr; using ValPtr = zeek::IntrusivePtr; extern std::string current_module; -namespace detail - { +namespace detail { class Frame; class Stmt; @@ -30,20 +29,14 @@ class DbgWatch; class DbgDisplay; // This needs to be defined before we do the includes that come after it. -enum ParseLocationRecType - { - PLR_UNKNOWN, - PLR_FILE_AND_LINE, - PLR_FUNCTION - }; -class ParseLocationRec - { +enum ParseLocationRecType { PLR_UNKNOWN, PLR_FILE_AND_LINE, PLR_FUNCTION }; +class ParseLocationRec { public: - ParseLocationRecType type; - int32_t line; - Stmt* stmt; - const char* filename; - }; + ParseLocationRecType type; + int32_t line; + Stmt* stmt; + const char* filename; +}; class StmtLocMapping; using Filemap = std::deque; // mapping for a single file @@ -51,94 +44,89 @@ using Filemap = std::deque; // mapping for a single file using BPIDMapType = std::map; using BPMapType = std::multimap; -class TraceState - { +class TraceState { public: - TraceState() - { - dbgtrace = false; - trace_file = stderr; - } + TraceState() { + dbgtrace = false; + trace_file = stderr; + } - // Returns previous filename. - FILE* SetTraceFile(const char* trace_filename); + // Returns previous filename. + FILE* SetTraceFile(const char* trace_filename); - bool DoTrace() const { return dbgtrace; } - void TraceOn(); - void TraceOff(); + bool DoTrace() const { return dbgtrace; } + void TraceOn(); + void TraceOff(); - int LogTrace(const char* fmt, ...) __attribute__((format(printf, 2, 3))); - ; + int LogTrace(const char* fmt, ...) __attribute__((format(printf, 2, 3))); + ; protected: - bool dbgtrace; // print an execution trace - FILE* trace_file; - }; + bool dbgtrace; // print an execution trace + FILE* trace_file; +}; extern TraceState g_trace_state; -class DebuggerState - { +class DebuggerState { public: - DebuggerState(); - ~DebuggerState(); + DebuggerState(); + ~DebuggerState(); - int NextBPID() { return next_bp_id++; } - int NextWatchID() { return next_watch_id++; } - int NextDisplayID() { return next_display_id++; } + int NextBPID() { return next_bp_id++; } + int NextWatchID() { return next_watch_id++; } + int NextDisplayID() { return next_display_id++; } - bool BreakBeforeNextStmt() { return break_before_next_stmt; } - void BreakBeforeNextStmt(bool dobrk) { break_before_next_stmt = dobrk; } + bool BreakBeforeNextStmt() { return break_before_next_stmt; } + void BreakBeforeNextStmt(bool dobrk) { break_before_next_stmt = dobrk; } - bool BreakFromSignal() { return break_from_signal; } - void BreakFromSignal(bool dobrk) { break_from_signal = dobrk; } + bool BreakFromSignal() { return break_from_signal; } + void BreakFromSignal(bool dobrk) { break_from_signal = dobrk; } - // Temporary state: vanishes when execution resumes. + // Temporary state: vanishes when execution resumes. - //### Umesh, why do these all need to be public? -- Vern + //### Umesh, why do these all need to be public? -- Vern - // Which frame we're looking at; 0 = the innermost frame. - int curr_frame_idx; + // Which frame we're looking at; 0 = the innermost frame. + int curr_frame_idx; - bool already_did_list; // did we already do a 'list' command? + bool already_did_list; // did we already do a 'list' command? - Location last_loc; // used by 'list'; the last location listed + Location last_loc; // used by 'list'; the last location listed - BPIDMapType breakpoints; // BPID -> Breakpoint - std::vector watches; - std::vector displays; - BPMapType breakpoint_map; // maps Stmt -> Breakpoints on it + BPIDMapType breakpoints; // BPID -> Breakpoint + std::vector watches; + std::vector displays; + BPMapType breakpoint_map; // maps Stmt -> Breakpoints on it protected: - bool break_before_next_stmt; // trap into debugger (used for "step") - bool break_from_signal; // was break caused by a signal? + bool break_before_next_stmt; // trap into debugger (used for "step") + bool break_from_signal; // was break caused by a signal? - int next_bp_id, next_watch_id, next_display_id; + int next_bp_id, next_watch_id, next_display_id; private: - Frame* dbg_locals; // unused - }; + Frame* dbg_locals; // unused +}; // Source line -> statement mapping. // (obj -> source line mapping available in object itself) -class StmtLocMapping - { +class StmtLocMapping { public: - StmtLocMapping() { } - StmtLocMapping(const Location* l, Stmt* s) - { - loc = *l; - stmt = s; - } + StmtLocMapping() {} + StmtLocMapping(const Location* l, Stmt* s) { + loc = *l; + stmt = s; + } - bool StartsAfter(const StmtLocMapping* m2); - const Location& Loc() const { return loc; } - Stmt* Statement() const { return stmt; } + bool StartsAfter(const StmtLocMapping* m2); + const Location& Loc() const { return loc; } + Stmt* Statement() const { return stmt; } protected: - Location loc; - Stmt* stmt = nullptr; - }; + Location loc; + Stmt* stmt = nullptr; +}; extern bool g_policy_debug; // enable debugging facility extern DebuggerState g_debugger_state; @@ -195,5 +183,5 @@ extern std::map g_dbgfilemaps; // filename => filemap // Perhaps add a code/priority argument to do selective output. int debug_msg(const char* fmt, ...) __attribute__((format(printf, 1, 2))); - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/DebugCmds.cc b/src/DebugCmds.cc index 5a89101952..0cb510b18c 100644 --- a/src/DebugCmds.cc +++ b/src/DebugCmds.cc @@ -26,675 +26,584 @@ using namespace std; -namespace zeek::detail - { +namespace zeek::detail { DebugCmdInfoQueue g_DebugCmdInfos; // // Helper routines // -static bool string_is_regex(const string& s) - { - return strpbrk(s.data(), "?*\\+"); - } +static bool string_is_regex(const string& s) { return strpbrk(s.data(), "?*\\+"); } -static void lookup_global_symbols_regex(const string& orig_regex, vector& matches, - bool func_only = false) - { - if ( util::streq(orig_regex.c_str(), "") ) - return; +static void lookup_global_symbols_regex(const string& orig_regex, vector& matches, bool func_only = false) { + if ( util::streq(orig_regex.c_str(), "") ) + return; - string regex = "^"; - int len = orig_regex.length(); - for ( int i = 0; i < len; ++i ) - { - if ( orig_regex[i] == '*' ) - regex.push_back('.'); - regex.push_back(orig_regex[i]); - } - regex.push_back('$'); + string regex = "^"; + int len = orig_regex.length(); + for ( int i = 0; i < len; ++i ) { + if ( orig_regex[i] == '*' ) + regex.push_back('.'); + regex.push_back(orig_regex[i]); + } + regex.push_back('$'); - regex_t re; - if ( regcomp(&re, regex.c_str(), REG_EXTENDED | REG_NOSUB) ) - { - debug_msg("Invalid regular expression: %s\n", regex.c_str()); - return; - } + regex_t re; + if ( regcomp(&re, regex.c_str(), REG_EXTENDED | REG_NOSUB) ) { + debug_msg("Invalid regular expression: %s\n", regex.c_str()); + return; + } - auto global = global_scope(); - const auto& syms = global->Vars(); + auto global = global_scope(); + const auto& syms = global->Vars(); - ID* nextid; - for ( const auto& sym : syms ) - { - ID* nextid = sym.second.get(); - if ( ! func_only || nextid->GetType()->Tag() == TYPE_FUNC ) - if ( ! regexec(&re, nextid->Name(), 0, 0, 0) ) - matches.push_back(nextid); - } - } + ID* nextid; + for ( const auto& sym : syms ) { + ID* nextid = sym.second.get(); + if ( ! func_only || nextid->GetType()->Tag() == TYPE_FUNC ) + if ( ! regexec(&re, nextid->Name(), 0, 0, 0) ) + matches.push_back(nextid); + } +} -static void choose_global_symbols_regex(const string& regex, vector& choices, - bool func_only = false) - { - lookup_global_symbols_regex(regex, choices, func_only); +static void choose_global_symbols_regex(const string& regex, vector& choices, bool func_only = false) { + lookup_global_symbols_regex(regex, choices, func_only); - if ( choices.size() <= 1 ) - return; + if ( choices.size() <= 1 ) + return; - while ( true ) - { - debug_msg("There were multiple matches, please choose:\n"); + while ( true ) { + debug_msg("There were multiple matches, please choose:\n"); - for ( size_t i = 0; i < choices.size(); i++ ) - debug_msg("[%zu] %s\n", i + 1, choices[i]->Name()); + for ( size_t i = 0; i < choices.size(); i++ ) + debug_msg("[%zu] %s\n", i + 1, choices[i]->Name()); - debug_msg("[a] All of the above\n"); - debug_msg("[n] None of the above\n"); - debug_msg("Enter your choice: "); + debug_msg("[a] All of the above\n"); + debug_msg("[n] None of the above\n"); + debug_msg("Enter your choice: "); - char charinput[256]; - if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) - { - choices.clear(); - return; - } - if ( charinput[strlen(charinput) - 1] == '\n' ) - charinput[strlen(charinput) - 1] = 0; + char charinput[256]; + if ( ! fgets(charinput, sizeof(charinput) - 1, stdin) ) { + choices.clear(); + return; + } + if ( charinput[strlen(charinput) - 1] == '\n' ) + charinput[strlen(charinput) - 1] = 0; - string input = charinput; - if ( input == "a" ) - return; + string input = charinput; + if ( input == "a" ) + return; - if ( input == "n" ) - { - choices.clear(); - return; - } - int option = atoi(input.c_str()); - if ( option > 0 && option <= (int)choices.size() ) - { - ID* choice = choices[option - 1]; - choices.clear(); - choices.push_back(choice); - return; - } - } - } + if ( input == "n" ) { + choices.clear(); + return; + } + int option = atoi(input.c_str()); + if ( option > 0 && option <= (int)choices.size() ) { + ID* choice = choices[option - 1]; + choices.clear(); + choices.push_back(choice); + return; + } + } +} // // DebugCmdInfo implementation // -DebugCmdInfo::DebugCmdInfo(const DebugCmdInfo& info) : cmd(info.cmd), helpstring(nullptr) - { - num_names = info.num_names; - names = info.names; - resume_execution = info.resume_execution; - repeatable = info.repeatable; - } +DebugCmdInfo::DebugCmdInfo(const DebugCmdInfo& info) : cmd(info.cmd), helpstring(nullptr) { + num_names = info.num_names; + names = info.names; + resume_execution = info.resume_execution; + repeatable = info.repeatable; +} -DebugCmdInfo::DebugCmdInfo(DebugCmd arg_cmd, const char* const* arg_names, int arg_num_names, - bool arg_resume_execution, const char* const arg_helpstring, - bool arg_repeatable) - : cmd(arg_cmd), helpstring(arg_helpstring) - { - num_names = arg_num_names; - resume_execution = arg_resume_execution; - repeatable = arg_repeatable; +DebugCmdInfo::DebugCmdInfo(DebugCmd arg_cmd, const char* const* arg_names, int arg_num_names, bool arg_resume_execution, + const char* const arg_helpstring, bool arg_repeatable) + : cmd(arg_cmd), helpstring(arg_helpstring) { + num_names = arg_num_names; + resume_execution = arg_resume_execution; + repeatable = arg_repeatable; - for ( int i = 0; i < num_names; ++i ) - names.push_back(arg_names[i]); - } + for ( int i = 0; i < num_names; ++i ) + names.push_back(arg_names[i]); +} -const DebugCmdInfo* get_debug_cmd_info(DebugCmd cmd) - { - if ( (int)cmd < g_DebugCmdInfos.size() ) - return g_DebugCmdInfos[(int)cmd]; - else - return nullptr; - } +const DebugCmdInfo* get_debug_cmd_info(DebugCmd cmd) { + if ( (int)cmd < g_DebugCmdInfos.size() ) + return g_DebugCmdInfos[(int)cmd]; + else + return nullptr; +} -int find_all_matching_cmds(const string& prefix, const char* array_of_matches[]) - { - // Trivial implementation for now (### use hashing later). +int find_all_matching_cmds(const string& prefix, const char* array_of_matches[]) { + // Trivial implementation for now (### use hashing later). - unsigned int arglen = prefix.length(); - int matches = 0; + unsigned int arglen = prefix.length(); + int matches = 0; - for ( int i = 0; i < num_debug_cmds(); ++i ) - { - array_of_matches[g_DebugCmdInfos[i]->Cmd()] = nullptr; + for ( int i = 0; i < num_debug_cmds(); ++i ) { + array_of_matches[g_DebugCmdInfos[i]->Cmd()] = nullptr; - for ( int j = 0; j < g_DebugCmdInfos[i]->NumNames(); ++j ) - { - const char* curr_name = g_DebugCmdInfos[i]->Names()[j]; - if ( strncmp(curr_name, prefix.c_str(), arglen) ) - continue; + for ( int j = 0; j < g_DebugCmdInfos[i]->NumNames(); ++j ) { + const char* curr_name = g_DebugCmdInfos[i]->Names()[j]; + if ( strncmp(curr_name, prefix.c_str(), arglen) ) + continue; - // If exact match, then only return that one. - if ( ! prefix.compare(curr_name) ) - { - for ( int k = 0; k < num_debug_cmds(); ++k ) - array_of_matches[k] = nullptr; + // If exact match, then only return that one. + if ( ! prefix.compare(curr_name) ) { + for ( int k = 0; k < num_debug_cmds(); ++k ) + array_of_matches[k] = nullptr; - array_of_matches[g_DebugCmdInfos[i]->Cmd()] = curr_name; - return 1; - } + array_of_matches[g_DebugCmdInfos[i]->Cmd()] = curr_name; + return 1; + } - array_of_matches[g_DebugCmdInfos[i]->Cmd()] = curr_name; - ++matches; - } - } + array_of_matches[g_DebugCmdInfos[i]->Cmd()] = curr_name; + ++matches; + } + } - return matches; - } + return matches; +} // // ------------------------------------------------------------ // Implementation of some debugger commands // Start, end bounds of which frame numbers to print -static int dbg_backtrace_internal(int start, int end) - { - if ( start < 0 || end < 0 || (unsigned)start >= g_frame_stack.size() || - (unsigned)end >= g_frame_stack.size() ) - reporter->InternalError("Invalid stack frame index in DbgBacktraceInternal\n"); +static int dbg_backtrace_internal(int start, int end) { + if ( start < 0 || end < 0 || (unsigned)start >= g_frame_stack.size() || (unsigned)end >= g_frame_stack.size() ) + reporter->InternalError("Invalid stack frame index in DbgBacktraceInternal\n"); - if ( start < end ) - { - int temp = start; - start = end; - end = temp; - } + if ( start < end ) { + int temp = start; + start = end; + end = temp; + } - for ( int i = start; i >= end; --i ) - { - const Frame* f = g_frame_stack[i]; - const Stmt* stmt = f ? f->GetNextStmt() : nullptr; + for ( int i = start; i >= end; --i ) { + const Frame* f = g_frame_stack[i]; + const Stmt* stmt = f ? f->GetNextStmt() : nullptr; - string context = get_context_description(stmt, f); - debug_msg("#%d %s\n", int(g_frame_stack.size() - 1 - i), context.c_str()); - }; + string context = get_context_description(stmt, f); + debug_msg("#%d %s\n", int(g_frame_stack.size() - 1 - i), context.c_str()); + }; - return 1; - } + return 1; +} // Returns 0 for illegal arguments, or 1 on success. -int dbg_cmd_backtrace(DebugCmd cmd, const vector& args) - { - assert(cmd == dcBacktrace); - assert(g_frame_stack.size() > 0); +int dbg_cmd_backtrace(DebugCmd cmd, const vector& args) { + assert(cmd == dcBacktrace); + assert(g_frame_stack.size() > 0); - unsigned int start_iter; - int end_iter; + unsigned int start_iter; + int end_iter; - if ( args.size() > 0 ) - { - int how_many; // determines how we traverse the frames - int valid_arg = sscanf(args[0].c_str(), "%i", &how_many); - if ( ! valid_arg ) - { - debug_msg("Argument to backtrace '%s' invalid: must be an integer\n", args[0].c_str()); - return 0; - } + if ( args.size() > 0 ) { + int how_many; // determines how we traverse the frames + int valid_arg = sscanf(args[0].c_str(), "%i", &how_many); + if ( ! valid_arg ) { + debug_msg("Argument to backtrace '%s' invalid: must be an integer\n", args[0].c_str()); + return 0; + } - if ( how_many > 0 ) - { // innermost N frames - start_iter = g_frame_stack.size() - 1; - end_iter = start_iter - how_many + 1; - if ( end_iter < 0 ) - end_iter = 0; - } - else - { // outermost N frames - start_iter = how_many - 1; - if ( start_iter + 1 > g_frame_stack.size() ) - start_iter = g_frame_stack.size() - 1; - end_iter = 0; - } - } - else - { - start_iter = g_frame_stack.size() - 1; - end_iter = 0; - } + if ( how_many > 0 ) { // innermost N frames + start_iter = g_frame_stack.size() - 1; + end_iter = start_iter - how_many + 1; + if ( end_iter < 0 ) + end_iter = 0; + } + else { // outermost N frames + start_iter = how_many - 1; + if ( start_iter + 1 > g_frame_stack.size() ) + start_iter = g_frame_stack.size() - 1; + end_iter = 0; + } + } + else { + start_iter = g_frame_stack.size() - 1; + end_iter = 0; + } - return dbg_backtrace_internal(start_iter, end_iter); - } + return dbg_backtrace_internal(start_iter, end_iter); +} // Returns 0 if invalid args, else 1. -int dbg_cmd_frame(DebugCmd cmd, const vector& args) - { - assert(cmd == dcFrame || cmd == dcUp || cmd == dcDown); +int dbg_cmd_frame(DebugCmd cmd, const vector& args) { + assert(cmd == dcFrame || cmd == dcUp || cmd == dcDown); - if ( cmd == dcFrame ) - { - int idx = 0; + if ( cmd == dcFrame ) { + int idx = 0; - if ( args.size() > 0 ) - { - if ( args.size() > 1 ) - { - debug_msg("Too many arguments: expecting frame number 'n'\n"); - return 0; - } + if ( args.size() > 0 ) { + if ( args.size() > 1 ) { + debug_msg("Too many arguments: expecting frame number 'n'\n"); + return 0; + } - if ( ! sscanf(args[0].c_str(), "%d", &idx) ) - { - debug_msg("Argument to frame must be a positive integer\n"); - return 0; - } + if ( ! sscanf(args[0].c_str(), "%d", &idx) ) { + debug_msg("Argument to frame must be a positive integer\n"); + return 0; + } - if ( idx < 0 || (unsigned int)idx >= g_frame_stack.size() ) - { - debug_msg("No frame %d", idx); - return 0; - } - } + if ( idx < 0 || (unsigned int)idx >= g_frame_stack.size() ) { + debug_msg("No frame %d", idx); + return 0; + } + } - g_debugger_state.curr_frame_idx = idx; - } + g_debugger_state.curr_frame_idx = idx; + } - else if ( cmd == dcDown ) - { - if ( g_debugger_state.curr_frame_idx == 0 ) - { - debug_msg("Innermost frame already selected\n"); - return 0; - } + else if ( cmd == dcDown ) { + if ( g_debugger_state.curr_frame_idx == 0 ) { + debug_msg("Innermost frame already selected\n"); + return 0; + } - g_debugger_state.curr_frame_idx--; - } + g_debugger_state.curr_frame_idx--; + } - else if ( cmd == dcUp ) - { - if ( (unsigned int)(g_debugger_state.curr_frame_idx + 1) == g_frame_stack.size() ) - { - debug_msg("Outermost frame already selected\n"); - return 0; - } + else if ( cmd == dcUp ) { + if ( (unsigned int)(g_debugger_state.curr_frame_idx + 1) == g_frame_stack.size() ) { + debug_msg("Outermost frame already selected\n"); + return 0; + } - ++g_debugger_state.curr_frame_idx; - } + ++g_debugger_state.curr_frame_idx; + } - int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx; + int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx; - // Set the current location to the new frame being looked at - // for 'list', 'break', etc. - const Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt(); - if ( ! stmt ) - reporter->InternalError("Assertion failed: stmt is null"); + // Set the current location to the new frame being looked at + // for 'list', 'break', etc. + const Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt(); + if ( ! stmt ) + reporter->InternalError("Assertion failed: stmt is null"); - const Location loc = *stmt->GetLocationInfo(); - g_debugger_state.last_loc = loc; - g_debugger_state.already_did_list = false; + const Location loc = *stmt->GetLocationInfo(); + g_debugger_state.last_loc = loc; + g_debugger_state.already_did_list = false; - return dbg_backtrace_internal(user_frame_number, user_frame_number); - } + return dbg_backtrace_internal(user_frame_number, user_frame_number); +} -int dbg_cmd_help(DebugCmd cmd, const vector& args) - { - assert(cmd == dcHelp); +int dbg_cmd_help(DebugCmd cmd, const vector& args) { + assert(cmd == dcHelp); - debug_msg("Help summary: \n\n"); - for ( int i = 1; i < num_debug_cmds(); ++i ) - { - const DebugCmdInfo* info = get_debug_cmd_info(DebugCmd(i)); - debug_msg("%s -- %s\n", info->Names()[0], info->Helpstring()); - } + debug_msg("Help summary: \n\n"); + for ( int i = 1; i < num_debug_cmds(); ++i ) { + const DebugCmdInfo* info = get_debug_cmd_info(DebugCmd(i)); + debug_msg("%s -- %s\n", info->Names()[0], info->Helpstring()); + } - return -1; - } + return -1; +} -int dbg_cmd_break(DebugCmd cmd, const vector& args) - { - assert(cmd == dcBreak); +int dbg_cmd_break(DebugCmd cmd, const vector& args) { + assert(cmd == dcBreak); - vector bps; + vector bps; - int cond_index = -1; // at which argument pos. does bp condition start? + int cond_index = -1; // at which argument pos. does bp condition start? - if ( args.empty() || args[0] == "if" ) - { // break on next stmt - int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx; + if ( args.empty() || args[0] == "if" ) { // break on next stmt + int user_frame_number = g_frame_stack.size() - 1 - g_debugger_state.curr_frame_idx; - Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt(); - if ( ! stmt ) - reporter->InternalError("Assertion failed: stmt is null"); + Stmt* stmt = g_frame_stack[user_frame_number]->GetNextStmt(); + if ( ! stmt ) + reporter->InternalError("Assertion failed: stmt is null"); - DbgBreakpoint* bp = new DbgBreakpoint(); - bp->SetID(g_debugger_state.NextBPID()); + DbgBreakpoint* bp = new DbgBreakpoint(); + bp->SetID(g_debugger_state.NextBPID()); - if ( ! bp->SetLocation(stmt) ) - { - debug_msg("Breakpoint not set.\n"); - delete bp; - return 0; - } + if ( ! bp->SetLocation(stmt) ) { + debug_msg("Breakpoint not set.\n"); + delete bp; + return 0; + } - if ( args.size() > 0 && args[0] == "if" ) - cond_index = 1; + if ( args.size() > 0 && args[0] == "if" ) + cond_index = 1; - bps.push_back(bp); - } + bps.push_back(bp); + } - else - { - vector locstrings; - if ( string_is_regex(args[0]) ) - { - vector choices; - choose_global_symbols_regex(args[0], choices, true); - for ( unsigned int i = 0; i < choices.size(); ++i ) - locstrings.emplace_back(choices[i]->Name()); - } - else - locstrings.push_back(args[0]); + else { + vector locstrings; + if ( string_is_regex(args[0]) ) { + vector choices; + choose_global_symbols_regex(args[0], choices, true); + for ( unsigned int i = 0; i < choices.size(); ++i ) + locstrings.emplace_back(choices[i]->Name()); + } + else + locstrings.push_back(args[0]); - for ( unsigned int strindex = 0; strindex < locstrings.size(); ++strindex ) - { - debug_msg("Setting breakpoint on %s:\n", locstrings[strindex].c_str()); - vector plrs = parse_location_string(locstrings[strindex]); - for ( const auto& plr : plrs ) - { - DbgBreakpoint* bp = new DbgBreakpoint(); - bp->SetID(g_debugger_state.NextBPID()); - if ( ! bp->SetLocation(plr, locstrings[strindex]) ) - { - debug_msg("Breakpoint not set.\n"); - delete bp; - } - else - bps.push_back(bp); - } - } + for ( unsigned int strindex = 0; strindex < locstrings.size(); ++strindex ) { + debug_msg("Setting breakpoint on %s:\n", locstrings[strindex].c_str()); + vector plrs = parse_location_string(locstrings[strindex]); + for ( const auto& plr : plrs ) { + DbgBreakpoint* bp = new DbgBreakpoint(); + bp->SetID(g_debugger_state.NextBPID()); + if ( ! bp->SetLocation(plr, locstrings[strindex]) ) { + debug_msg("Breakpoint not set.\n"); + delete bp; + } + else + bps.push_back(bp); + } + } - if ( args.size() > 1 && args[1] == "if" ) - cond_index = 2; - } + if ( args.size() > 1 && args[1] == "if" ) + cond_index = 2; + } - // Is there a condition specified? - if ( cond_index >= 0 && ! bps.empty() ) - { - // ### Implement conditions - string cond; - for ( const auto& arg : args ) - { - cond += arg; - cond += " "; - } - bps[0]->SetCondition(cond); - } + // Is there a condition specified? + if ( cond_index >= 0 && ! bps.empty() ) { + // ### Implement conditions + string cond; + for ( const auto& arg : args ) { + cond += arg; + cond += " "; + } + bps[0]->SetCondition(cond); + } - for ( auto& bp : bps ) - { - bp->SetTemporary(false); - g_debugger_state.breakpoints[bp->GetID()] = bp; - } + for ( auto& bp : bps ) { + bp->SetTemporary(false); + g_debugger_state.breakpoints[bp->GetID()] = bp; + } - return 0; - } + return 0; +} // Set a condition on an existing breakpoint. -int dbg_cmd_break_condition(DebugCmd cmd, const vector& args) - { - assert(cmd == dcBreakCondition); +int dbg_cmd_break_condition(DebugCmd cmd, const vector& args) { + assert(cmd == dcBreakCondition); - if ( args.size() < 2 ) - { - debug_msg("Arguments must specify breakpoint number and condition.\n"); - return 0; - } + if ( args.size() < 2 ) { + debug_msg("Arguments must specify breakpoint number and condition.\n"); + return 0; + } - int idx = atoi(args[0].c_str()); - DbgBreakpoint* bp = g_debugger_state.breakpoints[idx]; + int idx = atoi(args[0].c_str()); + DbgBreakpoint* bp = g_debugger_state.breakpoints[idx]; - string expr; - for ( int i = 1; i < int(args.size()); ++i ) - { - expr += args[i]; - expr += " "; - } - bp->SetCondition(expr); + string expr; + for ( int i = 1; i < int(args.size()); ++i ) { + expr += args[i]; + expr += " "; + } + bp->SetCondition(expr); - return 1; - } + return 1; +} // Change the state of a breakpoint. -int dbg_cmd_break_set_state(DebugCmd cmd, const vector& args) - { - assert(cmd == dcDeleteBreak || cmd == dcClearBreak || cmd == dcDisableBreak || - cmd == dcEnableBreak || cmd == dcIgnoreBreak); +int dbg_cmd_break_set_state(DebugCmd cmd, const vector& args) { + assert(cmd == dcDeleteBreak || cmd == dcClearBreak || cmd == dcDisableBreak || cmd == dcEnableBreak || + cmd == dcIgnoreBreak); - if ( cmd == dcClearBreak || cmd == dcIgnoreBreak ) - { - debug_msg("'clear' and 'ignore' commands not currently supported\n"); - return 0; - } + if ( cmd == dcClearBreak || cmd == dcIgnoreBreak ) { + debug_msg("'clear' and 'ignore' commands not currently supported\n"); + return 0; + } - if ( g_debugger_state.breakpoints.empty() ) - { - debug_msg("No breakpoints currently set.\n"); - return -1; - } + if ( g_debugger_state.breakpoints.empty() ) { + debug_msg("No breakpoints currently set.\n"); + return -1; + } - vector bps_to_change; + vector bps_to_change; - if ( args.empty() ) - { - BPIDMapType::iterator iter; - for ( iter = g_debugger_state.breakpoints.begin(); - iter != g_debugger_state.breakpoints.end(); ++iter ) - bps_to_change.push_back(iter->second->GetID()); - } - else - { - for ( const auto& arg : args ) - if ( int idx = atoi(arg.c_str()) ) - bps_to_change.push_back(idx); - } + if ( args.empty() ) { + BPIDMapType::iterator iter; + for ( iter = g_debugger_state.breakpoints.begin(); iter != g_debugger_state.breakpoints.end(); ++iter ) + bps_to_change.push_back(iter->second->GetID()); + } + else { + for ( const auto& arg : args ) + if ( int idx = atoi(arg.c_str()) ) + bps_to_change.push_back(idx); + } - for ( auto bp_change : bps_to_change ) - { - BPIDMapType::iterator result = g_debugger_state.breakpoints.find(bp_change); + for ( auto bp_change : bps_to_change ) { + BPIDMapType::iterator result = g_debugger_state.breakpoints.find(bp_change); - if ( result != g_debugger_state.breakpoints.end() ) - { - switch ( cmd ) - { - case dcDisableBreak: - g_debugger_state.breakpoints[bp_change]->SetEnable(false); - debug_msg("Breakpoint %d disabled\n", bp_change); - break; + if ( result != g_debugger_state.breakpoints.end() ) { + switch ( cmd ) { + case dcDisableBreak: + g_debugger_state.breakpoints[bp_change]->SetEnable(false); + debug_msg("Breakpoint %d disabled\n", bp_change); + break; - case dcEnableBreak: - g_debugger_state.breakpoints[bp_change]->SetEnable(true); - debug_msg("Breakpoint %d enabled\n", bp_change); - break; + case dcEnableBreak: + g_debugger_state.breakpoints[bp_change]->SetEnable(true); + debug_msg("Breakpoint %d enabled\n", bp_change); + break; - case dcDeleteBreak: - delete g_debugger_state.breakpoints[bp_change]; - g_debugger_state.breakpoints.erase(bp_change); - debug_msg("Breakpoint %d deleted\n", bp_change); - break; + case dcDeleteBreak: + delete g_debugger_state.breakpoints[bp_change]; + g_debugger_state.breakpoints.erase(bp_change); + debug_msg("Breakpoint %d deleted\n", bp_change); + break; - default: - reporter->InternalError("Invalid command in DbgCmdBreakSetState\n"); - } - } + default: reporter->InternalError("Invalid command in DbgCmdBreakSetState\n"); + } + } - else - debug_msg("Breakpoint %d does not exist\n", bp_change); - } + else + debug_msg("Breakpoint %d does not exist\n", bp_change); + } - return -1; - } + return -1; +} // Evaluate an expression and print the result. -int dbg_cmd_print(DebugCmd cmd, const vector& args) - { - assert(cmd == dcPrint); +int dbg_cmd_print(DebugCmd cmd, const vector& args) { + assert(cmd == dcPrint); - // ### TODO: add support for formats + // ### TODO: add support for formats - // Just concatenate all the 'args' into one expression. - string expr; - for ( size_t i = 0; i < args.size(); ++i ) - { - expr += args[i]; - if ( i < args.size() - 1 ) - expr += " "; - } + // Just concatenate all the 'args' into one expression. + string expr; + for ( size_t i = 0; i < args.size(); ++i ) { + expr += args[i]; + if ( i < args.size() - 1 ) + expr += " "; + } - auto val = dbg_eval_expr(expr.c_str()); + auto val = dbg_eval_expr(expr.c_str()); - if ( val ) - { - ODesc d; - val->Describe(&d); - debug_msg("%s\n", d.Description()); - } - else - { - debug_msg("\n"); - } + if ( val ) { + ODesc d; + val->Describe(&d); + debug_msg("%s\n", d.Description()); + } + else { + debug_msg("\n"); + } - return 1; - } + return 1; +} // Get the debugger's state. // Allowed arguments are: break (breakpoints), watch, display, source. -int dbg_cmd_info(DebugCmd cmd, const vector& args) - { - assert(cmd == dcInfo); +int dbg_cmd_info(DebugCmd cmd, const vector& args) { + assert(cmd == dcInfo); - if ( args.empty() ) - { - debug_msg("Syntax: info info-command\n"); - debug_msg("List of info-commands:\n"); - debug_msg("info breakpoints -- List of breakpoints and watches\n"); - return 1; - } + if ( args.empty() ) { + debug_msg("Syntax: info info-command\n"); + debug_msg("List of info-commands:\n"); + debug_msg("info breakpoints -- List of breakpoints and watches\n"); + return 1; + } - if ( ! strncmp(args[0].c_str(), "breakpoints", args[0].size()) || - ! strncmp(args[0].c_str(), "watch", args[0].size()) ) - { - debug_msg("Num Type Disp Enb What\n"); + if ( ! strncmp(args[0].c_str(), "breakpoints", args[0].size()) || + ! strncmp(args[0].c_str(), "watch", args[0].size()) ) { + debug_msg("Num Type Disp Enb What\n"); - BPIDMapType::iterator iter; - for ( iter = g_debugger_state.breakpoints.begin(); - iter != g_debugger_state.breakpoints.end(); ++iter ) - { - DbgBreakpoint* bp = (*iter).second; - debug_msg("%-4d%-15s%-5s%-4s%s\n", bp->GetID(), "breakpoint", - bp->IsTemporary() ? "del" : "keep", bp->IsEnabled() ? "y" : "n", - bp->Description()); - } - } + BPIDMapType::iterator iter; + for ( iter = g_debugger_state.breakpoints.begin(); iter != g_debugger_state.breakpoints.end(); ++iter ) { + DbgBreakpoint* bp = (*iter).second; + debug_msg("%-4d%-15s%-5s%-4s%s\n", bp->GetID(), "breakpoint", bp->IsTemporary() ? "del" : "keep", + bp->IsEnabled() ? "y" : "n", bp->Description()); + } + } - else - debug_msg("I don't have info for that yet.\n"); + else + debug_msg("I don't have info for that yet.\n"); - return 1; - } + return 1; +} -int dbg_cmd_list(DebugCmd cmd, const vector& args) - { - assert(cmd == dcList); +int dbg_cmd_list(DebugCmd cmd, const vector& args) { + assert(cmd == dcList); - // The constant 4 is to match the GDB behavior. - const unsigned int CENTER_IDX = 4; // 5th line is the 'interesting' one + // The constant 4 is to match the GDB behavior. + const unsigned int CENTER_IDX = 4; // 5th line is the 'interesting' one - int pre_offset = 0; - if ( args.size() > 1 ) - { - debug_msg("Syntax: list [file:]line OR list function_name\n"); - return 0; - } + int pre_offset = 0; + if ( args.size() > 1 ) { + debug_msg("Syntax: list [file:]line OR list function_name\n"); + return 0; + } - if ( args.empty() ) - { - // Special case: if we just hit a breakpoint, then show - // that line without advancing first. - if ( g_debugger_state.already_did_list ) - pre_offset = 10; - } + if ( args.empty() ) { + // Special case: if we just hit a breakpoint, then show + // that line without advancing first. + if ( g_debugger_state.already_did_list ) + pre_offset = 10; + } - else if ( args[0] == "-" ) - // Why -10 ? Because that's what GDB does. - pre_offset = -10; + else if ( args[0] == "-" ) + // Why -10 ? Because that's what GDB does. + pre_offset = -10; - else if ( args[0][0] == '-' || args[0][0] == '+' ) - { - int offset; - if ( ! sscanf(args[0].c_str(), "%d", &offset) ) - { - debug_msg("Offset must be a number\n"); - return false; - } + else if ( args[0][0] == '-' || args[0][0] == '+' ) { + int offset; + if ( ! sscanf(args[0].c_str(), "%d", &offset) ) { + debug_msg("Offset must be a number\n"); + return false; + } - pre_offset = offset; - } + pre_offset = offset; + } - else - { - vector plrs = parse_location_string(args[0]); - ParseLocationRec plr = plrs[0]; - if ( plr.type == PLR_UNKNOWN ) - { - debug_msg("Invalid location specifier\n"); - return false; - } + else { + vector plrs = parse_location_string(args[0]); + ParseLocationRec plr = plrs[0]; + if ( plr.type == PLR_UNKNOWN ) { + debug_msg("Invalid location specifier\n"); + return false; + } - g_debugger_state.last_loc.filename = plr.filename; - g_debugger_state.last_loc.first_line = plr.line; - pre_offset = 0; - } + g_debugger_state.last_loc.filename = plr.filename; + g_debugger_state.last_loc.first_line = plr.line; + pre_offset = 0; + } - if ( (int)pre_offset + (int)g_debugger_state.last_loc.first_line - (int)CENTER_IDX < 0 ) - pre_offset = CENTER_IDX - g_debugger_state.last_loc.first_line; + if ( (int)pre_offset + (int)g_debugger_state.last_loc.first_line - (int)CENTER_IDX < 0 ) + pre_offset = CENTER_IDX - g_debugger_state.last_loc.first_line; - g_debugger_state.last_loc.first_line += pre_offset; + g_debugger_state.last_loc.first_line += pre_offset; - int last_line_in_file = how_many_lines_in(g_debugger_state.last_loc.filename); + int last_line_in_file = how_many_lines_in(g_debugger_state.last_loc.filename); - if ( g_debugger_state.last_loc.first_line > last_line_in_file ) - g_debugger_state.last_loc.first_line = last_line_in_file; + if ( g_debugger_state.last_loc.first_line > last_line_in_file ) + g_debugger_state.last_loc.first_line = last_line_in_file; - PrintLines(g_debugger_state.last_loc.filename, - g_debugger_state.last_loc.first_line - CENTER_IDX, 10, true); + PrintLines(g_debugger_state.last_loc.filename, g_debugger_state.last_loc.first_line - CENTER_IDX, 10, true); - g_debugger_state.already_did_list = true; + g_debugger_state.already_did_list = true; - return 1; - } + return 1; +} -int dbg_cmd_trace(DebugCmd cmd, const vector& args) - { - assert(cmd == dcTrace); +int dbg_cmd_trace(DebugCmd cmd, const vector& args) { + assert(cmd == dcTrace); - if ( args.empty() ) - { - debug_msg("Execution tracing is %s.\n", g_trace_state.DoTrace() ? "on" : "off"); - return 1; - } + if ( args.empty() ) { + debug_msg("Execution tracing is %s.\n", g_trace_state.DoTrace() ? "on" : "off"); + return 1; + } - if ( args[0] == "on" ) - { - g_trace_state.TraceOn(); - return 1; - } + if ( args[0] == "on" ) { + g_trace_state.TraceOn(); + return 1; + } - if ( args[0] == "off" ) - { - g_trace_state.TraceOff(); - return 1; - } + if ( args[0] == "off" ) { + g_trace_state.TraceOff(); + return 1; + } - debug_msg("Invalid argument"); - return 0; - } + debug_msg("Invalid argument"); + return 0; +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DebugCmds.h b/src/DebugCmds.h index 3eb0ec6b2c..2e8d14c5ae 100644 --- a/src/DebugCmds.h +++ b/src/DebugCmds.h @@ -11,39 +11,37 @@ // This file is generated during the build. #include "DebugCmdConstants.h" -namespace zeek::detail - { +namespace zeek::detail { -class DebugCmdInfo - { +class DebugCmdInfo { public: - DebugCmdInfo(const DebugCmdInfo& info); + DebugCmdInfo(const DebugCmdInfo& info); - DebugCmdInfo(DebugCmd cmd, const char* const* names, int num_names, bool resume_execution, - const char* const helpstring, bool repeatable); + DebugCmdInfo(DebugCmd cmd, const char* const* names, int num_names, bool resume_execution, + const char* const helpstring, bool repeatable); - DebugCmdInfo() : helpstring(nullptr) { } + DebugCmdInfo() : helpstring(nullptr) {} - int Cmd() const { return cmd; } - int NumNames() const { return num_names; } - const std::vector& Names() const { return names; } - bool ResumeExecution() const { return resume_execution; } - const char* Helpstring() const { return helpstring; } - bool Repeatable() const { return repeatable; } + int Cmd() const { return cmd; } + int NumNames() const { return num_names; } + const std::vector& Names() const { return names; } + bool ResumeExecution() const { return resume_execution; } + const char* Helpstring() const { return helpstring; } + bool Repeatable() const { return repeatable; } protected: - DebugCmd cmd; + DebugCmd cmd; - int32_t num_names; - std::vector names; - const char* const helpstring; + int32_t num_names; + std::vector names; + const char* const helpstring; - // Whether executing this should restart execution of the script. - bool resume_execution; + // Whether executing this should restart execution of the script. + bool resume_execution; - // Does entering a blank line repeat this command? - bool repeatable; - }; + // Does entering a blank line repeat this command? + bool repeatable; +}; using DebugCmdInfoQueue = std::deque; extern DebugCmdInfoQueue g_DebugCmdInfos; @@ -80,4 +78,4 @@ DbgCmdFn dbg_cmd_info; DbgCmdFn dbg_cmd_list; DbgCmdFn dbg_cmd_trace; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/DebugLogger.cc b/src/DebugLogger.cc index 4ce647157b..8e964aac32 100644 --- a/src/DebugLogger.cc +++ b/src/DebugLogger.cc @@ -11,202 +11,178 @@ zeek::detail::DebugLogger zeek::detail::debug_logger; zeek::detail::DebugLogger& debug_logger = zeek::detail::debug_logger; -namespace zeek::detail - { +namespace zeek::detail { // Same order here as in DebugStream. -DebugLogger::Stream DebugLogger::streams[NUM_DBGS] = { - {"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, - {"notifiers", 0, false}, {"main-loop", 0, false}, {"dpd", 0, false}, - {"packet_analysis", 0, false}, {"file_analysis", 0, false}, {"tm", 0, false}, - {"logging", 0, false}, {"input", 0, false}, {"threading", 0, false}, - {"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, - {"broker", 0, false}, {"scripts", 0, false}, {"supervisor", 0, false}, - {"hashkey", 0, false}, {"spicy", 0, false}}; +DebugLogger::Stream DebugLogger::streams[NUM_DBGS] = + {{"serial", 0, false}, {"rules", 0, false}, {"string", 0, false}, {"notifiers", 0, false}, + {"main-loop", 0, false}, {"dpd", 0, false}, {"packet_analysis", 0, false}, {"file_analysis", 0, false}, + {"tm", 0, false}, {"logging", 0, false}, {"input", 0, false}, {"threading", 0, false}, + {"plugins", 0, false}, {"zeekygen", 0, false}, {"pktio", 0, false}, {"broker", 0, false}, + {"scripts", 0, false}, {"supervisor", 0, false}, {"hashkey", 0, false}, {"spicy", 0, false}}; -DebugLogger::DebugLogger() - { - verbose = false; - file = nullptr; - } +DebugLogger::DebugLogger() { + verbose = false; + file = nullptr; +} -DebugLogger::~DebugLogger() - { - if ( file && file != stderr ) - fclose(file); - } +DebugLogger::~DebugLogger() { + if ( file && file != stderr ) + fclose(file); +} -void DebugLogger::OpenDebugLog(const char* filename) - { - if ( filename ) - { - filename = util::detail::log_file_name(filename); +void DebugLogger::OpenDebugLog(const char* filename) { + if ( filename ) { + filename = util::detail::log_file_name(filename); - file = fopen(filename, "w"); - if ( ! file ) - { - // The reporter may not be initialized here yet. - if ( reporter ) - reporter->FatalError("can't open '%s' for debugging output", filename); - else - { - fprintf(stderr, "can't open '%s' for debugging output\n", filename); - exit(1); - } - } + file = fopen(filename, "w"); + if ( ! file ) { + // The reporter may not be initialized here yet. + if ( reporter ) + reporter->FatalError("can't open '%s' for debugging output", filename); + else { + fprintf(stderr, "can't open '%s' for debugging output\n", filename); + exit(1); + } + } - util::detail::setvbuf(file, NULL, _IOLBF, 0); - } - else - file = stderr; - } + util::detail::setvbuf(file, NULL, _IOLBF, 0); + } + else + file = stderr; +} -void DebugLogger::ShowStreamsHelp() - { - fprintf(stderr, "\n"); - fprintf(stderr, "Enable debug output into debug.log with -B .\n"); - fprintf(stderr, " is a comma-separated list of streams to enable.\n"); - fprintf(stderr, "\n"); - fprintf(stderr, "Available streams:\n"); +void DebugLogger::ShowStreamsHelp() { + fprintf(stderr, "\n"); + fprintf(stderr, "Enable debug output into debug.log with -B .\n"); + fprintf(stderr, " is a comma-separated list of streams to enable.\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "Available streams:\n"); - for ( int i = 0; i < NUM_DBGS; ++i ) - fprintf(stderr, " %s\n", streams[i].prefix); + for ( int i = 0; i < NUM_DBGS; ++i ) + fprintf(stderr, " %s\n", streams[i].prefix); - fprintf(stderr, "\n"); - fprintf(stderr, " plugin- (replace '::' in name with '-'; e.g., '-B " - "plugin-Zeek-Netmap')\n"); - fprintf(stderr, "\n"); - fprintf(stderr, "Pseudo streams\n"); - fprintf(stderr, " verbose Increase verbosity.\n"); - fprintf(stderr, " all Enable all streams at maximum verbosity.\n"); - fprintf(stderr, "\n"); - } + fprintf(stderr, "\n"); + fprintf(stderr, + " plugin- (replace '::' in name with '-'; e.g., '-B " + "plugin-Zeek-Netmap')\n"); + fprintf(stderr, "\n"); + fprintf(stderr, "Pseudo streams\n"); + fprintf(stderr, " verbose Increase verbosity.\n"); + fprintf(stderr, " all Enable all streams at maximum verbosity.\n"); + fprintf(stderr, "\n"); +} -void DebugLogger::EnableStreams(const char* s) - { - char* brkt; - char* tmp = util::copy_string(s); - char* tok = strtok(tmp, ","); +void DebugLogger::EnableStreams(const char* s) { + char* brkt; + char* tmp = util::copy_string(s); + char* tok = strtok(tmp, ","); - while ( tok ) - { - if ( strcasecmp("all", tok) == 0 ) - { - for ( int i = 0; i < NUM_DBGS; ++i ) - { - streams[i].enabled = true; - enabled_streams.insert(streams[i].prefix); - } + while ( tok ) { + if ( strcasecmp("all", tok) == 0 ) { + for ( int i = 0; i < NUM_DBGS; ++i ) { + streams[i].enabled = true; + enabled_streams.insert(streams[i].prefix); + } - verbose = true; - goto next; - } + verbose = true; + goto next; + } - if ( strcasecmp("verbose", tok) == 0 ) - { - verbose = true; - goto next; - } + if ( strcasecmp("verbose", tok) == 0 ) { + verbose = true; + goto next; + } - if ( strcasecmp("help", tok) == 0 ) - { - ShowStreamsHelp(); - exit(0); - } + if ( strcasecmp("help", tok) == 0 ) { + ShowStreamsHelp(); + exit(0); + } - if ( util::starts_with(tok, "plugin-") ) - { - // Cannot verify this at this time, plugins may not - // have been loaded. - enabled_streams.insert(tok); - goto next; - } + if ( util::starts_with(tok, "plugin-") ) { + // Cannot verify this at this time, plugins may not + // have been loaded. + enabled_streams.insert(tok); + goto next; + } - int i; + int i; - for ( i = 0; i < NUM_DBGS; ++i ) - { - if ( strcasecmp(streams[i].prefix, tok) == 0 ) - { - streams[i].enabled = true; - enabled_streams.insert(tok); - goto next; - } - } + for ( i = 0; i < NUM_DBGS; ++i ) { + if ( strcasecmp(streams[i].prefix, tok) == 0 ) { + streams[i].enabled = true; + enabled_streams.insert(tok); + goto next; + } + } - reporter->FatalError("unknown debug stream '%s', try -B help.\n", tok); + reporter->FatalError("unknown debug stream '%s', try -B help.\n", tok); - next: - tok = strtok(0, ","); - } + next: + tok = strtok(0, ","); + } - delete[] tmp; - } + delete[] tmp; +} -bool DebugLogger::CheckStreams(const std::set& plugin_names) - { - bool ok = true; +bool DebugLogger::CheckStreams(const std::set& plugin_names) { + bool ok = true; - std::set available_plugin_streams; - for ( const auto& p : plugin_names ) - available_plugin_streams.insert(PluginStreamName(p)); + std::set available_plugin_streams; + for ( const auto& p : plugin_names ) + available_plugin_streams.insert(PluginStreamName(p)); - for ( const auto& stream : enabled_streams ) - { - if ( ! util::starts_with(stream, "plugin-") ) - continue; + for ( const auto& stream : enabled_streams ) { + if ( ! util::starts_with(stream, "plugin-") ) + continue; - if ( available_plugin_streams.count(stream) == 0 ) - { - reporter->Error("No plugin debug stream '%s' found", stream.c_str()); - ok = false; - } - } + if ( available_plugin_streams.count(stream) == 0 ) { + reporter->Error("No plugin debug stream '%s' found", stream.c_str()); + ok = false; + } + } - return ok; - } + return ok; +} -void DebugLogger::Log(DebugStream stream, const char* fmt, ...) - { - Stream* g = &streams[int(stream)]; +void DebugLogger::Log(DebugStream stream, const char* fmt, ...) { + Stream* g = &streams[int(stream)]; - if ( ! g->enabled ) - return; + if ( ! g->enabled ) + return; - fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), - g->prefix); + fprintf(file, "%17.06f/%17.06f [%s] ", run_state::network_time, util::current_time(true), g->prefix); - for ( int i = g->indent; i > 0; --i ) - fputs(" ", file); + for ( int i = g->indent; i > 0; --i ) + fputs(" ", file); - va_list ap; - va_start(ap, fmt); - vfprintf(file, fmt, ap); - va_end(ap); + va_list ap; + va_start(ap, fmt); + vfprintf(file, fmt, ap); + va_end(ap); - fputc('\n', file); - fflush(file); - } + fputc('\n', file); + fflush(file); +} -void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) - { - std::string tok = PluginStreamName(plugin.Name()); +void DebugLogger::Log(const plugin::Plugin& plugin, const char* fmt, ...) { + std::string tok = PluginStreamName(plugin.Name()); - if ( enabled_streams.find(tok) == enabled_streams.end() ) - return; + if ( enabled_streams.find(tok) == enabled_streams.end() ) + return; - fprintf(file, "%17.06f/%17.06f [plugin %s] ", run_state::network_time, util::current_time(true), - plugin.Name().c_str()); + fprintf(file, "%17.06f/%17.06f [plugin %s] ", run_state::network_time, util::current_time(true), + plugin.Name().c_str()); - va_list ap; - va_start(ap, fmt); - vfprintf(file, fmt, ap); - va_end(ap); + va_list ap; + va_start(ap, fmt); + vfprintf(file, fmt, ap); + va_end(ap); - fputc('\n', file); - fflush(file); - } + fputc('\n', file); + fflush(file); +} - } // namespace zeek::detail +} // namespace zeek::detail #endif diff --git a/src/DebugLogger.h b/src/DebugLogger.h index 041e18e407..d22373b3ae 100644 --- a/src/DebugLogger.h +++ b/src/DebugLogger.h @@ -13,115 +13,106 @@ #include "zeek/util.h" -#define DBG_LOG(stream, ...) \ - if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \ - ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) -#define DBG_LOG_VERBOSE(stream, ...) \ - if ( ::zeek::detail::debug_logger.IsVerbose() && \ - ::zeek::detail::debug_logger.IsEnabled(stream) ) \ - ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) +#define DBG_LOG(stream, ...) \ + if ( ::zeek::detail::debug_logger.IsEnabled(stream) ) \ + ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) +#define DBG_LOG_VERBOSE(stream, ...) \ + if ( ::zeek::detail::debug_logger.IsVerbose() && ::zeek::detail::debug_logger.IsEnabled(stream) ) \ + ::zeek::detail::debug_logger.Log(stream, __VA_ARGS__) #define DBG_PUSH(stream) ::zeek::detail::debug_logger.PushIndent(stream) #define DBG_POP(stream) ::zeek::detail::debug_logger.PopIndent(stream) #define PLUGIN_DBG_LOG(plugin, ...) ::zeek::detail::debug_logger.Log(plugin, __VA_ARGS__) -namespace zeek - { +namespace zeek { -namespace plugin - { +namespace plugin { class Plugin; - } +} // To add a new debugging stream, add a constant here as well as // an entry to DebugLogger::streams in DebugLogger.cc. -enum DebugStream - { - DBG_SERIAL, // Serialization - DBG_RULES, // Signature matching - DBG_STRING, // String code - DBG_NOTIFIERS, // Notifiers - DBG_MAINLOOP, // Main IOSource loop - DBG_ANALYZER, // Analyzer framework - DBG_PACKET_ANALYSIS, // Packet analysis - DBG_FILE_ANALYSIS, // File analysis - DBG_TM, // Time-machine packet input via Broccoli - DBG_LOGGING, // Logging streams - DBG_INPUT, // Input streams - DBG_THREADING, // Threading system - DBG_PLUGINS, // Plugin system - DBG_ZEEKYGEN, // Zeekygen - DBG_PKTIO, // Packet sources and dumpers. - DBG_BROKER, // Broker communication - DBG_SCRIPTS, // Script initialization - DBG_SUPERVISOR, // Process supervisor - DBG_HASHKEY, // HashKey buffers - DBG_SPICY, // Spicy functionality +enum DebugStream { + DBG_SERIAL, // Serialization + DBG_RULES, // Signature matching + DBG_STRING, // String code + DBG_NOTIFIERS, // Notifiers + DBG_MAINLOOP, // Main IOSource loop + DBG_ANALYZER, // Analyzer framework + DBG_PACKET_ANALYSIS, // Packet analysis + DBG_FILE_ANALYSIS, // File analysis + DBG_TM, // Time-machine packet input via Broccoli + DBG_LOGGING, // Logging streams + DBG_INPUT, // Input streams + DBG_THREADING, // Threading system + DBG_PLUGINS, // Plugin system + DBG_ZEEKYGEN, // Zeekygen + DBG_PKTIO, // Packet sources and dumpers. + DBG_BROKER, // Broker communication + DBG_SCRIPTS, // Script initialization + DBG_SUPERVISOR, // Process supervisor + DBG_HASHKEY, // HashKey buffers + DBG_SPICY, // Spicy functionality - NUM_DBGS // Has to be last - }; + NUM_DBGS // Has to be last +}; -namespace detail - { +namespace detail { -class DebugLogger - { +class DebugLogger { public: - // Output goes to stderr per default. - DebugLogger(); - ~DebugLogger(); + // Output goes to stderr per default. + DebugLogger(); + ~DebugLogger(); - void OpenDebugLog(const char* filename = 0); + void OpenDebugLog(const char* filename = 0); - void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4))); - void Log(const plugin::Plugin& plugin, const char* fmt, ...) - __attribute__((format(printf, 3, 4))); + void Log(DebugStream stream, const char* fmt, ...) __attribute__((format(printf, 3, 4))); + void Log(const plugin::Plugin& plugin, const char* fmt, ...) __attribute__((format(printf, 3, 4))); - void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; } - void PopIndent(DebugStream stream) { --streams[int(stream)].indent; } + void PushIndent(DebugStream stream) { ++streams[int(stream)].indent; } + void PopIndent(DebugStream stream) { --streams[int(stream)].indent; } - void EnableStream(DebugStream stream) { streams[int(stream)].enabled = true; } - void DisableStream(DebugStream stream) { streams[int(stream)].enabled = false; } + void EnableStream(DebugStream stream) { streams[int(stream)].enabled = true; } + void DisableStream(DebugStream stream) { streams[int(stream)].enabled = false; } - // Takes comma-separated list of stream prefixes. - void EnableStreams(const char* streams); + // Takes comma-separated list of stream prefixes. + void EnableStreams(const char* streams); - // Check the enabled streams for invalid ones. - bool CheckStreams(const std::set& plugin_names); + // Check the enabled streams for invalid ones. + bool CheckStreams(const std::set& plugin_names); - bool IsEnabled(DebugStream stream) const { return streams[int(stream)].enabled; } + bool IsEnabled(DebugStream stream) const { return streams[int(stream)].enabled; } - void SetVerbose(bool arg_verbose) { verbose = arg_verbose; } - bool IsVerbose() const { return verbose; } + void SetVerbose(bool arg_verbose) { verbose = arg_verbose; } + bool IsVerbose() const { return verbose; } - void ShowStreamsHelp(); + void ShowStreamsHelp(); private: - FILE* file; - bool verbose; + FILE* file; + bool verbose; - struct Stream - { - const char* prefix; - int indent; - bool enabled; - }; + struct Stream { + const char* prefix; + int indent; + bool enabled; + }; - std::set enabled_streams; + std::set enabled_streams; - static Stream streams[NUM_DBGS]; + static Stream streams[NUM_DBGS]; - const std::string PluginStreamName(const std::string& plugin_name) - { - return "plugin-" + util::strreplace(plugin_name, "::", "-"); - } - }; + const std::string PluginStreamName(const std::string& plugin_name) { + return "plugin-" + util::strreplace(plugin_name, "::", "-"); + } +}; extern DebugLogger debug_logger; - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek #else #define DBG_LOG(...) diff --git a/src/Desc.cc b/src/Desc.cc index 7dd1c2e61f..27660dce97 100644 --- a/src/Desc.cc +++ b/src/Desc.cc @@ -17,428 +17,362 @@ #define DEFAULT_SIZE 128 #define SLOP 10 -namespace zeek - { - -ODesc::ODesc(DescType t, File* arg_f) - { - type = t; - style = STANDARD_STYLE; - f = arg_f; - - if ( f == nullptr ) - { - size = DEFAULT_SIZE; - base = util::safe_malloc(size); - ((char*)base)[0] = '\0'; - offset = 0; - } - else - { - offset = size = 0; - base = nullptr; - } - - indent_level = 0; - is_short = false; - want_quotes = false; - want_determinism = false; - do_flush = true; - include_stats = false; - indent_with_spaces = 0; - escape = false; - utf8 = false; - } - -ODesc::~ODesc() - { - if ( f ) - { - if ( do_flush ) - f->Flush(); - } - else if ( base ) - free(base); - } - -void ODesc::EnableEscaping() - { - escape = true; - } - -void ODesc::EnableUTF8() - { - utf8 = true; - } - -void ODesc::PushIndent() - { - ++indent_level; - NL(); - } - -void ODesc::PopIndent() - { - if ( --indent_level < 0 ) - reporter->InternalError("ODesc::PopIndent underflow"); - - NL(); - } - -void ODesc::PopIndentNoNL() - { - if ( --indent_level < 0 ) - reporter->InternalError("ODesc::PopIndent underflow"); - } - -void ODesc::Add(const char* s, int do_indent) - { - unsigned int n = strlen(s); - - if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' ) - Indent(); - - if ( IsBinary() ) - AddBytes(s, n + 1); - else - AddBytes(s, n); - } - -void ODesc::Add(int i) - { - if ( IsBinary() ) - AddBytes(&i, sizeof(i)); - else - { - char tmp[256]; - modp_litoa10(i, tmp); - Add(tmp); - } - } - -void ODesc::Add(uint32_t u) - { - if ( IsBinary() ) - AddBytes(&u, sizeof(u)); - else - { - char tmp[256]; - modp_ulitoa10(u, tmp); - Add(tmp); - } - } - -void ODesc::Add(int64_t i) - { - if ( IsBinary() ) - AddBytes(&i, sizeof(i)); - else - { - char tmp[256]; - modp_litoa10(i, tmp); - Add(tmp); - } - } - -void ODesc::Add(uint64_t u) - { - if ( IsBinary() ) - AddBytes(&u, sizeof(u)); - else - { - char tmp[256]; - modp_ulitoa10(u, tmp); - Add(tmp); - } - } - -void ODesc::Add(double d, bool no_exp) - { - if ( IsBinary() ) - AddBytes(&d, sizeof(d)); - else - { - // Buffer needs enough chars to store max. possible "double" value - // of 1.79e308 without using scientific notation. - char tmp[350]; - - if ( no_exp ) - modp_dtoa3(d, tmp, sizeof(tmp), IsReadable() ? 6 : 8); - else - modp_dtoa2(d, tmp, IsReadable() ? 6 : 8); - - Add(tmp); - - auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool - { - auto v = a - b; - return v < 0 ? -v < tolerance : v < tolerance; - }; - - if ( approx_equal(d, nearbyint(d), 1e-9) && std::isfinite(d) && ! strchr(tmp, 'e') ) - // disambiguate from integer - Add(".0"); - } - } - -void ODesc::Add(const IPAddr& addr) - { - Add(addr.AsString()); - } - -void ODesc::Add(const IPPrefix& prefix) - { - Add(prefix.AsString()); - } - -void ODesc::AddCS(const char* s) - { - int n = strlen(s); - Add(n); - if ( ! IsBinary() ) - Add(" "); - Add(s); - } - -void ODesc::AddBytes(const String* s) - { - if ( IsReadable() ) - { - if ( Style() == RAW_STYLE ) - AddBytes(reinterpret_cast(s->Bytes()), s->Len()); - else - { - const char* str = s->Render(String::EXPANDED_STRING); - Add(str); - delete[] str; - } - } - else - { - Add(s->Len()); - if ( ! IsBinary() ) - Add(" "); - AddBytes(s->Bytes(), s->Len()); - } - } - -void ODesc::Indent() - { - if ( indent_with_spaces > 0 ) - { - for ( int i = 0; i < indent_level; ++i ) - for ( int j = 0; j < indent_with_spaces; ++j ) - Add(" ", 0); - } - else - { - for ( int i = 0; i < indent_level; ++i ) - Add("\t", 0); - } - } - -static bool starts_with(const char* str1, const char* str2, size_t len) - { - for ( size_t i = 0; i < len; ++i ) - if ( str1[i] != str2[i] ) - return false; - - return true; - } - -size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) - { - if ( escape_sequences.empty() ) - return 0; - - for ( const auto& esc_str : escape_sequences ) - { - size_t esc_len = esc_str.length(); - - if ( start + esc_len > end ) - continue; - - if ( starts_with(start, esc_str.c_str(), esc_len) ) - return esc_len; - } - - return 0; - } - -std::pair ODesc::FirstEscapeLoc(const char* bytes, size_t n) - { - if ( IsBinary() ) - return {nullptr, 0}; - - for ( size_t i = 0; i < n; ++i ) - { - auto printable = isprint(bytes[i]); - - if ( ! printable && ! utf8 ) - return {bytes + i, 1}; - - if ( bytes[i] == '\\' ) - return {bytes + i, 1}; - - size_t len = StartsWithEscapeSequence(bytes + i, bytes + n); - - if ( len ) - return {bytes + i, len}; - } - - return {nullptr, 0}; - } - -void ODesc::AddBytes(const void* bytes, unsigned int n) - { - if ( ! escape ) - { - AddBytesRaw(bytes, n); - return; - } - - const char* s = (const char*)bytes; - const char* e = (const char*)bytes + n; - - while ( s < e ) - { - auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s); - - if ( esc_start != nullptr ) - { - if ( utf8 ) - { - std::string result = util::json_escape_utf8(s, esc_start - s, false); - AddBytesRaw(result.c_str(), result.size()); - } - else - AddBytesRaw(s, esc_start - s); - - util::get_escaped_string(this, esc_start, esc_len, true); - s = esc_start + esc_len; - } - else - { - if ( utf8 ) - { - std::string result = util::json_escape_utf8(s, e - s, false); - AddBytesRaw(result.c_str(), result.size()); - } - else - AddBytesRaw(s, e - s); - - break; - } - } - } - -void ODesc::AddBytesRaw(const void* bytes, unsigned int n) - { - if ( n == 0 ) - return; - - if ( f ) - { - static bool write_failed = false; - - if ( ! f->Write((const char*)bytes, n) ) - { - if ( ! write_failed ) - // Most likely it's a "disk full" so report - // subsequent failures only once. - reporter->Error("error writing to %s: %s", f->Name(), strerror(errno)); - - write_failed = true; - return; - } - - write_failed = false; - } - - else - { - Grow(n); - - // The following casting contortions are necessary because - // simply using &base[offset] generates complaints about - // using a void* for pointer arithmetic. - memcpy((void*)&((char*)base)[offset], bytes, n); - offset += n; - - ((char*)base)[offset] = '\0'; // ensure that always NUL-term. - } - } - -void ODesc::Grow(unsigned int n) - { - bool size_changed = false; - while ( offset + n + SLOP >= size ) - { - size *= 2; - size_changed = true; - } - - if ( size_changed ) - base = util::safe_realloc(base, size); - } - -void ODesc::Clear() - { - offset = 0; - - // If we've allocated an exceedingly large amount of space, free it. - if ( size > 10 * 1024 * 1024 ) - { - free(base); - size = DEFAULT_SIZE; - base = util::safe_malloc(size); - ((char*)base)[0] = '\0'; - } - } - -bool ODesc::PushType(const Type* type) - { - auto res = encountered_types.insert(type); - return std::get<1>(res); - } - -bool ODesc::PopType(const Type* type) - { - size_t res = encountered_types.erase(type); - return (res == 1); - } - -bool ODesc::FindType(const Type* type) - { - auto res = encountered_types.find(type); - - if ( res != encountered_types.end() ) - return true; - - return false; - } - -std::string obj_desc(const Obj* o) - { - static ODesc d; - - d.Clear(); - o->Describe(&d); - d.SP(); - o->GetLocationInfo()->Describe(&d); - - return d.Description(); - } - -std::string obj_desc_short(const Obj* o) - { - static ODesc d; - - d.SetShort(true); - d.Clear(); - o->Describe(&d); - - return d.Description(); - } - - } // namespace zeek +namespace zeek { + +ODesc::ODesc(DescType t, File* arg_f) { + type = t; + style = STANDARD_STYLE; + f = arg_f; + + if ( f == nullptr ) { + size = DEFAULT_SIZE; + base = util::safe_malloc(size); + ((char*)base)[0] = '\0'; + offset = 0; + } + else { + offset = size = 0; + base = nullptr; + } + + indent_level = 0; + is_short = false; + want_quotes = false; + want_determinism = false; + do_flush = true; + include_stats = false; + indent_with_spaces = 0; + escape = false; + utf8 = false; +} + +ODesc::~ODesc() { + if ( f ) { + if ( do_flush ) + f->Flush(); + } + else if ( base ) + free(base); +} + +void ODesc::EnableEscaping() { escape = true; } + +void ODesc::EnableUTF8() { utf8 = true; } + +void ODesc::PushIndent() { + ++indent_level; + NL(); +} + +void ODesc::PopIndent() { + if ( --indent_level < 0 ) + reporter->InternalError("ODesc::PopIndent underflow"); + + NL(); +} + +void ODesc::PopIndentNoNL() { + if ( --indent_level < 0 ) + reporter->InternalError("ODesc::PopIndent underflow"); +} + +void ODesc::Add(const char* s, int do_indent) { + unsigned int n = strlen(s); + + if ( do_indent && IsReadable() && offset > 0 && ((const char*)base)[offset - 1] == '\n' ) + Indent(); + + if ( IsBinary() ) + AddBytes(s, n + 1); + else + AddBytes(s, n); +} + +void ODesc::Add(int i) { + if ( IsBinary() ) + AddBytes(&i, sizeof(i)); + else { + char tmp[256]; + modp_litoa10(i, tmp); + Add(tmp); + } +} + +void ODesc::Add(uint32_t u) { + if ( IsBinary() ) + AddBytes(&u, sizeof(u)); + else { + char tmp[256]; + modp_ulitoa10(u, tmp); + Add(tmp); + } +} + +void ODesc::Add(int64_t i) { + if ( IsBinary() ) + AddBytes(&i, sizeof(i)); + else { + char tmp[256]; + modp_litoa10(i, tmp); + Add(tmp); + } +} + +void ODesc::Add(uint64_t u) { + if ( IsBinary() ) + AddBytes(&u, sizeof(u)); + else { + char tmp[256]; + modp_ulitoa10(u, tmp); + Add(tmp); + } +} + +void ODesc::Add(double d, bool no_exp) { + if ( IsBinary() ) + AddBytes(&d, sizeof(d)); + else { + // Buffer needs enough chars to store max. possible "double" value + // of 1.79e308 without using scientific notation. + char tmp[350]; + + if ( no_exp ) + modp_dtoa3(d, tmp, sizeof(tmp), IsReadable() ? 6 : 8); + else + modp_dtoa2(d, tmp, IsReadable() ? 6 : 8); + + Add(tmp); + + auto approx_equal = [](double a, double b, double tolerance = 1e-6) -> bool { + auto v = a - b; + return v < 0 ? -v < tolerance : v < tolerance; + }; + + if ( approx_equal(d, nearbyint(d), 1e-9) && std::isfinite(d) && ! strchr(tmp, 'e') ) + // disambiguate from integer + Add(".0"); + } +} + +void ODesc::Add(const IPAddr& addr) { Add(addr.AsString()); } + +void ODesc::Add(const IPPrefix& prefix) { Add(prefix.AsString()); } + +void ODesc::AddCS(const char* s) { + int n = strlen(s); + Add(n); + if ( ! IsBinary() ) + Add(" "); + Add(s); +} + +void ODesc::AddBytes(const String* s) { + if ( IsReadable() ) { + if ( Style() == RAW_STYLE ) + AddBytes(reinterpret_cast(s->Bytes()), s->Len()); + else { + const char* str = s->Render(String::EXPANDED_STRING); + Add(str); + delete[] str; + } + } + else { + Add(s->Len()); + if ( ! IsBinary() ) + Add(" "); + AddBytes(s->Bytes(), s->Len()); + } +} + +void ODesc::Indent() { + if ( indent_with_spaces > 0 ) { + for ( int i = 0; i < indent_level; ++i ) + for ( int j = 0; j < indent_with_spaces; ++j ) + Add(" ", 0); + } + else { + for ( int i = 0; i < indent_level; ++i ) + Add("\t", 0); + } +} + +static bool starts_with(const char* str1, const char* str2, size_t len) { + for ( size_t i = 0; i < len; ++i ) + if ( str1[i] != str2[i] ) + return false; + + return true; +} + +size_t ODesc::StartsWithEscapeSequence(const char* start, const char* end) { + if ( escape_sequences.empty() ) + return 0; + + for ( const auto& esc_str : escape_sequences ) { + size_t esc_len = esc_str.length(); + + if ( start + esc_len > end ) + continue; + + if ( starts_with(start, esc_str.c_str(), esc_len) ) + return esc_len; + } + + return 0; +} + +std::pair ODesc::FirstEscapeLoc(const char* bytes, size_t n) { + if ( IsBinary() ) + return {nullptr, 0}; + + for ( size_t i = 0; i < n; ++i ) { + auto printable = isprint(bytes[i]); + + if ( ! printable && ! utf8 ) + return {bytes + i, 1}; + + if ( bytes[i] == '\\' ) + return {bytes + i, 1}; + + size_t len = StartsWithEscapeSequence(bytes + i, bytes + n); + + if ( len ) + return {bytes + i, len}; + } + + return {nullptr, 0}; +} + +void ODesc::AddBytes(const void* bytes, unsigned int n) { + if ( ! escape ) { + AddBytesRaw(bytes, n); + return; + } + + const char* s = (const char*)bytes; + const char* e = (const char*)bytes + n; + + while ( s < e ) { + auto [esc_start, esc_len] = FirstEscapeLoc(s, e - s); + + if ( esc_start != nullptr ) { + if ( utf8 ) { + std::string result = util::json_escape_utf8(s, esc_start - s, false); + AddBytesRaw(result.c_str(), result.size()); + } + else + AddBytesRaw(s, esc_start - s); + + util::get_escaped_string(this, esc_start, esc_len, true); + s = esc_start + esc_len; + } + else { + if ( utf8 ) { + std::string result = util::json_escape_utf8(s, e - s, false); + AddBytesRaw(result.c_str(), result.size()); + } + else + AddBytesRaw(s, e - s); + + break; + } + } +} + +void ODesc::AddBytesRaw(const void* bytes, unsigned int n) { + if ( n == 0 ) + return; + + if ( f ) { + static bool write_failed = false; + + if ( ! f->Write((const char*)bytes, n) ) { + if ( ! write_failed ) + // Most likely it's a "disk full" so report + // subsequent failures only once. + reporter->Error("error writing to %s: %s", f->Name(), strerror(errno)); + + write_failed = true; + return; + } + + write_failed = false; + } + + else { + Grow(n); + + // The following casting contortions are necessary because + // simply using &base[offset] generates complaints about + // using a void* for pointer arithmetic. + memcpy((void*)&((char*)base)[offset], bytes, n); + offset += n; + + ((char*)base)[offset] = '\0'; // ensure that always NUL-term. + } +} + +void ODesc::Grow(unsigned int n) { + bool size_changed = false; + while ( offset + n + SLOP >= size ) { + size *= 2; + size_changed = true; + } + + if ( size_changed ) + base = util::safe_realloc(base, size); +} + +void ODesc::Clear() { + offset = 0; + + // If we've allocated an exceedingly large amount of space, free it. + if ( size > 10 * 1024 * 1024 ) { + free(base); + size = DEFAULT_SIZE; + base = util::safe_malloc(size); + ((char*)base)[0] = '\0'; + } +} + +bool ODesc::PushType(const Type* type) { + auto res = encountered_types.insert(type); + return std::get<1>(res); +} + +bool ODesc::PopType(const Type* type) { + size_t res = encountered_types.erase(type); + return (res == 1); +} + +bool ODesc::FindType(const Type* type) { + auto res = encountered_types.find(type); + + if ( res != encountered_types.end() ) + return true; + + return false; +} + +std::string obj_desc(const Obj* o) { + static ODesc d; + + d.Clear(); + o->Describe(&d); + d.SP(); + o->GetLocationInfo()->Describe(&d); + + return d.Description(); +} + +std::string obj_desc_short(const Obj* o) { + static ODesc d; + + d.SetShort(true); + d.Clear(); + o->Describe(&d); + + return d.Description(); +} + +} // namespace zeek diff --git a/src/Desc.h b/src/Desc.h index e907693da0..ed092cb0b3 100644 --- a/src/Desc.h +++ b/src/Desc.h @@ -8,222 +8,207 @@ #include #include "zeek/ZeekString.h" // for byte_vec -#include "zeek/util.h" // for zeek_int_t +#include "zeek/util.h" // for zeek_int_t -namespace zeek - { +namespace zeek { class IPAddr; class IPPrefix; class File; class Type; -enum DescType - { - DESC_READABLE, - DESC_BINARY, - }; +enum DescType { + DESC_READABLE, + DESC_BINARY, +}; -enum DescStyle - { - STANDARD_STYLE, - RAW_STYLE, - }; +enum DescStyle { + STANDARD_STYLE, + RAW_STYLE, +}; -class ODesc - { +class ODesc { public: - explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr); + explicit ODesc(DescType t = DESC_READABLE, File* f = nullptr); - ~ODesc(); + ~ODesc(); - bool IsReadable() const { return type == DESC_READABLE; } - bool IsBinary() const { return type == DESC_BINARY; } + bool IsReadable() const { return type == DESC_READABLE; } + bool IsBinary() const { return type == DESC_BINARY; } - bool IsShort() const { return is_short; } - void SetShort() { is_short = true; } - void SetShort(bool s) { is_short = s; } + bool IsShort() const { return is_short; } + void SetShort() { is_short = true; } + void SetShort(bool s) { is_short = s; } - // Whether we want to have quotes around strings. - bool WantQuotes() const { return want_quotes; } - void SetQuotes(bool q) { want_quotes = q; } + // Whether we want to have quotes around strings. + bool WantQuotes() const { return want_quotes; } + void SetQuotes(bool q) { want_quotes = q; } - // Whether to ensure deterministic output (for example, when - // describing TableVal's). - bool WantDeterminism() const { return want_determinism; } - void SetDeterminism(bool d) { want_determinism = d; } + // Whether to ensure deterministic output (for example, when + // describing TableVal's). + bool WantDeterminism() const { return want_determinism; } + void SetDeterminism(bool d) { want_determinism = d; } - // Whether we want to print statistics like access time and execution - // count where available. - bool IncludeStats() const { return include_stats; } - void SetIncludeStats(bool s) { include_stats = s; } + // Whether we want to print statistics like access time and execution + // count where available. + bool IncludeStats() const { return include_stats; } + void SetIncludeStats(bool s) { include_stats = s; } - DescStyle Style() const { return style; } - void SetStyle(DescStyle s) { style = s; } + DescStyle Style() const { return style; } + void SetStyle(DescStyle s) { style = s; } - void SetFlush(bool arg_do_flush) { do_flush = arg_do_flush; } + void SetFlush(bool arg_do_flush) { do_flush = arg_do_flush; } - void EnableEscaping(); - void EnableUTF8(); - void AddEscapeSequence(const char* s) { escape_sequences.insert(s); } - void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); } - void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); } - void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); } - void RemoveEscapeSequence(const char* s, size_t n) - { - escape_sequences.erase(std::string(s, n)); - } - void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); } + void EnableEscaping(); + void EnableUTF8(); + void AddEscapeSequence(const char* s) { escape_sequences.insert(s); } + void AddEscapeSequence(const char* s, size_t n) { escape_sequences.insert(std::string(s, n)); } + void AddEscapeSequence(const std::string& s) { escape_sequences.insert(s); } + void RemoveEscapeSequence(const char* s) { escape_sequences.erase(s); } + void RemoveEscapeSequence(const char* s, size_t n) { escape_sequences.erase(std::string(s, n)); } + void RemoveEscapeSequence(const std::string& s) { escape_sequences.erase(s); } - void PushIndent(); - void PopIndent(); - void PopIndentNoNL(); - int GetIndentLevel() const { return indent_level; } - void ClearIndentLevel() { indent_level = 0; } + void PushIndent(); + void PopIndent(); + void PopIndentNoNL(); + int GetIndentLevel() const { return indent_level; } + void ClearIndentLevel() { indent_level = 0; } - int IndentSpaces() const { return indent_with_spaces; } - void SetIndentSpaces(int i) { indent_with_spaces = i; } + int IndentSpaces() const { return indent_with_spaces; } + void SetIndentSpaces(int i) { indent_with_spaces = i; } - void Add(const char* s, int do_indent = 1); - void AddN(const char* s, int len) { AddBytes(s, len); } - void Add(const std::string& s) { AddBytes(s.data(), s.size()); } - void Add(int i); - void Add(uint32_t u); - void Add(int64_t i); - void Add(uint64_t u); - void Add(double d, bool no_exp = false); - void Add(const IPAddr& addr); - void Add(const IPPrefix& prefix); + void Add(const char* s, int do_indent = 1); + void AddN(const char* s, int len) { AddBytes(s, len); } + void Add(const std::string& s) { AddBytes(s.data(), s.size()); } + void Add(int i); + void Add(uint32_t u); + void Add(int64_t i); + void Add(uint64_t u); + void Add(double d, bool no_exp = false); + void Add(const IPAddr& addr); + void Add(const IPPrefix& prefix); - // Add s as a counted string. - void AddCS(const char* s); + // Add s as a counted string. + void AddCS(const char* s); - void AddBytes(const String* s); + void AddBytes(const String* s); - void Add(const char* s1, const char* s2) - { - Add(s1); - Add(s2); - } + void Add(const char* s1, const char* s2) { + Add(s1); + Add(s2); + } - void AddSP(const char* s1, const char* s2) - { - Add(s1); - AddSP(s2); - } + void AddSP(const char* s1, const char* s2) { + Add(s1); + AddSP(s2); + } - void AddSP(const char* s) - { - Add(s); - SP(); - } + void AddSP(const char* s) { + Add(s); + SP(); + } - void AddCount(zeek_int_t n) - { - if ( ! IsReadable() ) - { - Add(n); - SP(); - } - } + void AddCount(zeek_int_t n) { + if ( ! IsReadable() ) { + Add(n); + SP(); + } + } - void SP() - { - if ( ! IsBinary() ) - Add(" ", 0); - } - void NL() - { - if ( ! IsBinary() && ! is_short ) - Add("\n", 0); - } + void SP() { + if ( ! IsBinary() ) + Add(" ", 0); + } + void NL() { + if ( ! IsBinary() && ! is_short ) + Add("\n", 0); + } - // Bypasses the escaping enabled via EnableEscaping(). - void AddRaw(const char* s, int len) { AddBytesRaw(s, len); } - void AddRaw(const std::string& s) { AddBytesRaw(s.data(), s.size()); } + // Bypasses the escaping enabled via EnableEscaping(). + void AddRaw(const char* s, int len) { AddBytesRaw(s, len); } + void AddRaw(const std::string& s) { AddBytesRaw(s.data(), s.size()); } - // Returns the description as a string. - const char* Description() const { return (const char*)base; } + // Returns the description as a string. + const char* Description() const { return (const char*)base; } - const u_char* Bytes() const { return (const u_char*)base; } - byte_vec TakeBytes() - { - const void* t = base; - base = nullptr; - size = 0; + const u_char* Bytes() const { return (const u_char*)base; } + byte_vec TakeBytes() { + const void* t = base; + base = nullptr; + size = 0; - // Don't clear offset, as we want to still support - // subsequent calls to Len(). + // Don't clear offset, as we want to still support + // subsequent calls to Len(). - return byte_vec(t); - } + return byte_vec(t); + } - int Len() const { return offset; } + int Len() const { return offset; } - void Clear(); + void Clear(); - // Used to determine recursive types. Records push their types on here; - // if the same type (by address) is re-encountered, processing aborts. - bool PushType(const Type* type); - bool PopType(const Type* type); - bool FindType(const Type* type); + // Used to determine recursive types. Records push their types on here; + // if the same type (by address) is re-encountered, processing aborts. + bool PushType(const Type* type); + bool PopType(const Type* type); + bool FindType(const Type* type); protected: - void Indent(); + void Indent(); - void AddBytes(const void* bytes, unsigned int n); - void AddBytesRaw(const void* bytes, unsigned int n); + void AddBytes(const void* bytes, unsigned int n); + void AddBytesRaw(const void* bytes, unsigned int n); - // Make buffer big enough for n bytes beyond bufp. - void Grow(unsigned int n); + // Make buffer big enough for n bytes beyond bufp. + void Grow(unsigned int n); - /** - * Returns the location of the first place in the bytes to be hex-escaped. - * - * @param bytes the starting memory address to start searching for - * escapable character. - * @param n the maximum number of bytes to search. - * @return a pair whose first element represents a starting memory address - * to be escaped up to the number of characters indicated by the - * second element. The first element may be 0 if nothing is - * to be escaped. - */ - std::pair FirstEscapeLoc(const char* bytes, size_t n); + /** + * Returns the location of the first place in the bytes to be hex-escaped. + * + * @param bytes the starting memory address to start searching for + * escapable character. + * @param n the maximum number of bytes to search. + * @return a pair whose first element represents a starting memory address + * to be escaped up to the number of characters indicated by the + * second element. The first element may be 0 if nothing is + * to be escaped. + */ + std::pair FirstEscapeLoc(const char* bytes, size_t n); - /** - * @param start start of string to check for starting with an escape - * sequence. - * @param end one byte past the last character in the string. - * @return The number of bytes in the escape sequence that the string - * starts with. - */ - size_t StartsWithEscapeSequence(const char* start, const char* end); + /** + * @param start start of string to check for starting with an escape + * sequence. + * @param end one byte past the last character in the string. + * @return The number of bytes in the escape sequence that the string + * starts with. + */ + size_t StartsWithEscapeSequence(const char* start, const char* end); - DescType type; - DescStyle style; + DescType type; + DescStyle style; - void* base; // beginning of buffer - unsigned int offset; // where we are in the buffer - unsigned int size; // size of buffer in bytes + void* base; // beginning of buffer + unsigned int offset; // where we are in the buffer + unsigned int size; // size of buffer in bytes - bool utf8; // whether valid utf-8 sequences may pass through unescaped - bool escape; // escape unprintable characters in output? - bool is_short; - bool want_quotes; - bool want_determinism; - bool do_flush; - bool include_stats; + bool utf8; // whether valid utf-8 sequences may pass through unescaped + bool escape; // escape unprintable characters in output? + bool is_short; + bool want_quotes; + bool want_determinism; + bool do_flush; + bool include_stats; - int indent_with_spaces; - int indent_level; + int indent_with_spaces; + int indent_level; - using escape_set = std::set; - escape_set escape_sequences; // additional sequences of chars to escape + using escape_set = std::set; + escape_set escape_sequences; // additional sequences of chars to escape - File* f; // or the file we're using. + File* f; // or the file we're using. - std::set encountered_types; - }; + std::set encountered_types; +}; // Returns a string representation of an object's description. Used for // debugging and error messages. takes a bare pointer rather than an @@ -235,4 +220,4 @@ std::string obj_desc(const Obj* o); // Same as obj_desc(), but ensure it is short and don't include location info. std::string obj_desc_short(const Obj* o); - } // namespace zeek +} // namespace zeek diff --git a/src/Dict.cc b/src/Dict.cc index a8fc362f12..c02444bc57 100644 --- a/src/Dict.cc +++ b/src/Dict.cc @@ -5,458 +5,435 @@ #include "zeek/3rdparty/doctest.h" #include "zeek/Hash.h" -namespace zeek - { +namespace zeek { // namespace detail TEST_SUITE_BEGIN("Dict"); -TEST_CASE("dict construction") - { - PDict dict; - CHECK(! dict.IsOrdered()); - CHECK(dict.Length() == 0); - - PDict dict2(ORDERED); - CHECK(dict2.IsOrdered()); - CHECK(dict2.Length() == 0); - } +TEST_CASE("dict construction") { + PDict dict; + CHECK(! dict.IsOrdered()); + CHECK(dict.Length() == 0); -TEST_CASE("dict operation") - { - PDict dict; + PDict dict2(ORDERED); + CHECK(dict2.IsOrdered()); + CHECK(dict2.Length() == 0); +} - uint32_t val = 10; - uint32_t key_val = 5; - - detail::HashKey* key = new detail::HashKey(key_val); - dict.Insert(key, &val); - CHECK(dict.Length() == 1); - - detail::HashKey* key2 = new detail::HashKey(key_val); - uint32_t* lookup = dict.Lookup(key2); - CHECK(*lookup == val); - - dict.Remove(key2); - CHECK(dict.Length() == 0); - uint32_t* lookup2 = dict.Lookup(key2); - CHECK(lookup2 == (uint32_t*)0); - delete key2; - - CHECK(dict.MaxLength() == 1); - CHECK(dict.NumCumulativeInserts() == 1); - - dict.Insert(key, &val); - dict.Remove(key); - - CHECK(dict.MaxLength() == 1); - CHECK(dict.NumCumulativeInserts() == 2); - - uint32_t val2 = 15; - uint32_t key_val2 = 25; - key2 = new detail::HashKey(key_val2); - - dict.Insert(key, &val); - dict.Insert(key2, &val2); - CHECK(dict.Length() == 2); - CHECK(dict.NumCumulativeInserts() == 4); - - dict.Clear(); - CHECK(dict.Length() == 0); - - delete key; - delete key2; - } - -TEST_CASE("dict nthentry") - { - PDict unordered(UNORDERED); - PDict ordered(ORDERED); - - uint32_t val = 15; - uint32_t key_val = 5; - detail::HashKey* okey = new detail::HashKey(key_val); - detail::HashKey* ukey = new detail::HashKey(key_val); - - uint32_t val2 = 10; - uint32_t key_val2 = 25; - detail::HashKey* okey2 = new detail::HashKey(key_val2); - detail::HashKey* ukey2 = new detail::HashKey(key_val2); - - unordered.Insert(ukey, &val); - unordered.Insert(ukey2, &val2); - - ordered.Insert(okey, &val); - ordered.Insert(okey2, &val2); - - // NthEntry returns null for unordered dicts - uint32_t* lookup = unordered.NthEntry(0); - CHECK(lookup == (uint32_t*)0); - - // Ordered dicts are based on order of insertion, nothing about the - // data itself - lookup = ordered.NthEntry(0); - CHECK(*lookup == 15); - - delete okey; - delete okey2; - delete ukey; - delete ukey2; - } - -TEST_CASE("dict iteration") - { - PDict dict; - - uint32_t val = 15; - uint32_t key_val = 5; - detail::HashKey* key = new detail::HashKey(key_val); - - uint32_t val2 = 10; - uint32_t key_val2 = 25; - detail::HashKey* key2 = new detail::HashKey(key_val2); - - dict.Insert(key, &val); - dict.Insert(key2, &val2); - - int count = 0; - - for ( const auto& entry : dict ) - { - auto* v = static_cast(entry.value); - uint64_t k = *(uint32_t*)entry.GetKey(); - - switch ( count ) - { - case 0: - CHECK(k == key_val2); - CHECK(*v == val2); - break; - case 1: - CHECK(k == key_val); - CHECK(*v == val); - break; - default: - break; - } - - count++; - } - - PDict::iterator it; - it = dict.begin(); - it = dict.end(); - PDict::iterator it2 = it; - - CHECK(count == 2); - - delete key; - delete key2; - } - -TEST_CASE("dict robust iteration") - { - PDict dict; - - uint32_t val = 15; - uint32_t key_val = 5; - detail::HashKey* key = new detail::HashKey(key_val); - - uint32_t val2 = 10; - uint32_t key_val2 = 25; - detail::HashKey* key2 = new detail::HashKey(key_val2); - - uint32_t val3 = 20; - uint32_t key_val3 = 35; - detail::HashKey* key3 = new detail::HashKey(key_val3); - - dict.Insert(key, &val); - dict.Insert(key2, &val2); - - { - int count = 0; - auto it = dict.begin_robust(); - - for ( ; it != dict.end_robust(); ++it ) - { - auto* v = it->value; - uint64_t k = *(uint32_t*)it->GetKey(); - - switch ( count ) - { - case 0: - CHECK(k == key_val2); - CHECK(*v == val2); - dict.Insert(key3, &val3); - break; - case 1: - CHECK(k == key_val); - CHECK(*v == val); - break; - case 2: - CHECK(k == key_val3); - CHECK(*v == val3); - break; - default: - // We shouldn't get here. - CHECK(false); - break; - } - count++; - } - - CHECK(count == 3); - } - - { - int count = 0; - auto it = dict.begin_robust(); - - for ( ; it != dict.end_robust(); ++it ) - { - auto* v = it->value; - uint64_t k = *(uint32_t*)it->GetKey(); - - switch ( count ) - { - case 0: - CHECK(k == key_val2); - CHECK(*v == val2); - dict.Insert(key3, &val3); - dict.Remove(key3); - break; - case 1: - CHECK(k == key_val); - CHECK(*v == val); - break; - default: - // We shouldn't get here. - CHECK(false); - break; - } - count++; - } - - CHECK(count == 2); - } - - delete key; - delete key2; - delete key3; - } - -TEST_CASE("dict ordered iteration") - { - PDict dict(DictOrder::ORDERED); - - // These key values are specifically contrived to be inserted - // into the dictionary in a different order by default. - uint32_t val = 15; - uint32_t key_val = 5; - auto key = std::make_unique(key_val); - - uint32_t val2 = 10; - uint32_t key_val2 = 25; - auto key2 = std::make_unique(key_val2); - - uint32_t val3 = 30; - uint32_t key_val3 = 45; - auto key3 = std::make_unique(key_val3); - - uint32_t val4 = 20; - uint32_t key_val4 = 35; - auto key4 = std::make_unique(key_val4); - - // Only insert the first three to start with so we can test the order - // being the same after a later insertion. - dict.Insert(key.get(), &val); - dict.Insert(key2.get(), &val2); - dict.Insert(key3.get(), &val3); - - int count = 0; - - for ( const auto& entry : dict ) - { - auto* v = static_cast(entry.value); - uint32_t k = *(uint32_t*)entry.GetKey(); - - // The keys should be returned in the same order we inserted - // them, which is 5, 25, 45. - if ( count == 0 ) - CHECK(k == 5); - else if ( count == 1 ) - CHECK(k == 25); - else if ( count == 2 ) - CHECK(k == 45); - - count++; - } - - dict.Insert(key4.get(), &val4); - count = 0; - - for ( const auto& entry : dict ) - { - auto* v = static_cast(entry.value); - uint32_t k = *(uint32_t*)entry.GetKey(); - - // The keys should be returned in the same order we inserted - // them, which is 5, 25, 45, 35. - if ( count == 0 ) - CHECK(k == 5); - else if ( count == 1 ) - CHECK(k == 25); - else if ( count == 2 ) - CHECK(k == 45); - else if ( count == 3 ) - CHECK(k == 35); - - count++; - } - - dict.Remove(key2.get()); - count = 0; - - for ( const auto& entry : dict ) - { - auto* v = static_cast(entry.value); - uint32_t k = *(uint32_t*)entry.GetKey(); - - // The keys should be returned in the same order we inserted - // them, which is 5, 45, 35. - if ( count == 0 ) - CHECK(k == 5); - else if ( count == 1 ) - CHECK(k == 45); - else if ( count == 2 ) - CHECK(k == 35); - - count++; - } - } - -class DictTestDummy - { +TEST_CASE("dict operation") { + PDict dict; + + uint32_t val = 10; + uint32_t key_val = 5; + + detail::HashKey* key = new detail::HashKey(key_val); + dict.Insert(key, &val); + CHECK(dict.Length() == 1); + + detail::HashKey* key2 = new detail::HashKey(key_val); + uint32_t* lookup = dict.Lookup(key2); + CHECK(*lookup == val); + + dict.Remove(key2); + CHECK(dict.Length() == 0); + uint32_t* lookup2 = dict.Lookup(key2); + CHECK(lookup2 == (uint32_t*)0); + delete key2; + + CHECK(dict.MaxLength() == 1); + CHECK(dict.NumCumulativeInserts() == 1); + + dict.Insert(key, &val); + dict.Remove(key); + + CHECK(dict.MaxLength() == 1); + CHECK(dict.NumCumulativeInserts() == 2); + + uint32_t val2 = 15; + uint32_t key_val2 = 25; + key2 = new detail::HashKey(key_val2); + + dict.Insert(key, &val); + dict.Insert(key2, &val2); + CHECK(dict.Length() == 2); + CHECK(dict.NumCumulativeInserts() == 4); + + dict.Clear(); + CHECK(dict.Length() == 0); + + delete key; + delete key2; +} + +TEST_CASE("dict nthentry") { + PDict unordered(UNORDERED); + PDict ordered(ORDERED); + + uint32_t val = 15; + uint32_t key_val = 5; + detail::HashKey* okey = new detail::HashKey(key_val); + detail::HashKey* ukey = new detail::HashKey(key_val); + + uint32_t val2 = 10; + uint32_t key_val2 = 25; + detail::HashKey* okey2 = new detail::HashKey(key_val2); + detail::HashKey* ukey2 = new detail::HashKey(key_val2); + + unordered.Insert(ukey, &val); + unordered.Insert(ukey2, &val2); + + ordered.Insert(okey, &val); + ordered.Insert(okey2, &val2); + + // NthEntry returns null for unordered dicts + uint32_t* lookup = unordered.NthEntry(0); + CHECK(lookup == (uint32_t*)0); + + // Ordered dicts are based on order of insertion, nothing about the + // data itself + lookup = ordered.NthEntry(0); + CHECK(*lookup == 15); + + delete okey; + delete okey2; + delete ukey; + delete ukey2; +} + +TEST_CASE("dict iteration") { + PDict dict; + + uint32_t val = 15; + uint32_t key_val = 5; + detail::HashKey* key = new detail::HashKey(key_val); + + uint32_t val2 = 10; + uint32_t key_val2 = 25; + detail::HashKey* key2 = new detail::HashKey(key_val2); + + dict.Insert(key, &val); + dict.Insert(key2, &val2); + + int count = 0; + + for ( const auto& entry : dict ) { + auto* v = static_cast(entry.value); + uint64_t k = *(uint32_t*)entry.GetKey(); + + switch ( count ) { + case 0: + CHECK(k == key_val2); + CHECK(*v == val2); + break; + case 1: + CHECK(k == key_val); + CHECK(*v == val); + break; + default: break; + } + + count++; + } + + PDict::iterator it; + it = dict.begin(); + it = dict.end(); + PDict::iterator it2 = it; + + CHECK(count == 2); + + delete key; + delete key2; +} + +TEST_CASE("dict robust iteration") { + PDict dict; + + uint32_t val = 15; + uint32_t key_val = 5; + detail::HashKey* key = new detail::HashKey(key_val); + + uint32_t val2 = 10; + uint32_t key_val2 = 25; + detail::HashKey* key2 = new detail::HashKey(key_val2); + + uint32_t val3 = 20; + uint32_t key_val3 = 35; + detail::HashKey* key3 = new detail::HashKey(key_val3); + + dict.Insert(key, &val); + dict.Insert(key2, &val2); + + { + int count = 0; + auto it = dict.begin_robust(); + + for ( ; it != dict.end_robust(); ++it ) { + auto* v = it->value; + uint64_t k = *(uint32_t*)it->GetKey(); + + switch ( count ) { + case 0: + CHECK(k == key_val2); + CHECK(*v == val2); + dict.Insert(key3, &val3); + break; + case 1: + CHECK(k == key_val); + CHECK(*v == val); + break; + case 2: + CHECK(k == key_val3); + CHECK(*v == val3); + break; + default: + // We shouldn't get here. + CHECK(false); + break; + } + count++; + } + + CHECK(count == 3); + } + + { + int count = 0; + auto it = dict.begin_robust(); + + for ( ; it != dict.end_robust(); ++it ) { + auto* v = it->value; + uint64_t k = *(uint32_t*)it->GetKey(); + + switch ( count ) { + case 0: + CHECK(k == key_val2); + CHECK(*v == val2); + dict.Insert(key3, &val3); + dict.Remove(key3); + break; + case 1: + CHECK(k == key_val); + CHECK(*v == val); + break; + default: + // We shouldn't get here. + CHECK(false); + break; + } + count++; + } + + CHECK(count == 2); + } + + delete key; + delete key2; + delete key3; +} + +TEST_CASE("dict ordered iteration") { + PDict dict(DictOrder::ORDERED); + + // These key values are specifically contrived to be inserted + // into the dictionary in a different order by default. + uint32_t val = 15; + uint32_t key_val = 5; + auto key = std::make_unique(key_val); + + uint32_t val2 = 10; + uint32_t key_val2 = 25; + auto key2 = std::make_unique(key_val2); + + uint32_t val3 = 30; + uint32_t key_val3 = 45; + auto key3 = std::make_unique(key_val3); + + uint32_t val4 = 20; + uint32_t key_val4 = 35; + auto key4 = std::make_unique(key_val4); + + // Only insert the first three to start with so we can test the order + // being the same after a later insertion. + dict.Insert(key.get(), &val); + dict.Insert(key2.get(), &val2); + dict.Insert(key3.get(), &val3); + + int count = 0; + + for ( const auto& entry : dict ) { + auto* v = static_cast(entry.value); + uint32_t k = *(uint32_t*)entry.GetKey(); + + // The keys should be returned in the same order we inserted + // them, which is 5, 25, 45. + if ( count == 0 ) + CHECK(k == 5); + else if ( count == 1 ) + CHECK(k == 25); + else if ( count == 2 ) + CHECK(k == 45); + + count++; + } + + dict.Insert(key4.get(), &val4); + count = 0; + + for ( const auto& entry : dict ) { + auto* v = static_cast(entry.value); + uint32_t k = *(uint32_t*)entry.GetKey(); + + // The keys should be returned in the same order we inserted + // them, which is 5, 25, 45, 35. + if ( count == 0 ) + CHECK(k == 5); + else if ( count == 1 ) + CHECK(k == 25); + else if ( count == 2 ) + CHECK(k == 45); + else if ( count == 3 ) + CHECK(k == 35); + + count++; + } + + dict.Remove(key2.get()); + count = 0; + + for ( const auto& entry : dict ) { + auto* v = static_cast(entry.value); + uint32_t k = *(uint32_t*)entry.GetKey(); + + // The keys should be returned in the same order we inserted + // them, which is 5, 45, 35. + if ( count == 0 ) + CHECK(k == 5); + else if ( count == 1 ) + CHECK(k == 45); + else if ( count == 2 ) + CHECK(k == 35); + + count++; + } +} + +class DictTestDummy { public: - DictTestDummy(int v) : v(v) { } - ~DictTestDummy() = default; - int v = 0; - }; + DictTestDummy(int v) : v(v) {} + ~DictTestDummy() = default; + int v = 0; +}; -TEST_CASE("dict robust iteration replacement") - { - PDict dict; +TEST_CASE("dict robust iteration replacement") { + PDict dict; - DictTestDummy* val1 = new DictTestDummy(15); - uint32_t key_val1 = 5; - detail::HashKey* key1 = new detail::HashKey(key_val1); + DictTestDummy* val1 = new DictTestDummy(15); + uint32_t key_val1 = 5; + detail::HashKey* key1 = new detail::HashKey(key_val1); - DictTestDummy* val2 = new DictTestDummy(10); - uint32_t key_val2 = 25; - detail::HashKey* key2 = new detail::HashKey(key_val2); + DictTestDummy* val2 = new DictTestDummy(10); + uint32_t key_val2 = 25; + detail::HashKey* key2 = new detail::HashKey(key_val2); - DictTestDummy* val3 = new DictTestDummy(20); - uint32_t key_val3 = 35; - detail::HashKey* key3 = new detail::HashKey(key_val3); + DictTestDummy* val3 = new DictTestDummy(20); + uint32_t key_val3 = 35; + detail::HashKey* key3 = new detail::HashKey(key_val3); - dict.Insert(key1, val1); - dict.Insert(key2, val2); - dict.Insert(key3, val3); + dict.Insert(key1, val1); + dict.Insert(key2, val2); + dict.Insert(key3, val3); - int count = 0; - auto it = dict.begin_robust(); + int count = 0; + auto it = dict.begin_robust(); - // Iterate past the first couple of elements so we're not done, but the - // iterator is still pointing at a valid element. - for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) { } + // Iterate past the first couple of elements so we're not done, but the + // iterator is still pointing at a valid element. + for ( ; count != 2 && it != dict.end_robust(); ++count, ++it ) { + } - // Store off the value at this iterator index - auto* v = it->value; + // Store off the value at this iterator index + auto* v = it->value; - // Replace it with something else - auto k = it->GetHashKey(); - DictTestDummy* val4 = new DictTestDummy(50); - dict.Insert(k.get(), val4); + // Replace it with something else + auto k = it->GetHashKey(); + DictTestDummy* val4 = new DictTestDummy(50); + dict.Insert(k.get(), val4); - // Delete the original element - delete val2; + // Delete the original element + delete val2; - // This shouldn't crash with AddressSanitizer - for ( ; it != dict.end_robust(); ++it ) - { - uint64_t k = *(uint32_t*)it->GetKey(); - auto* v = it->value; - CHECK(v->v == 50); - } + // This shouldn't crash with AddressSanitizer + for ( ; it != dict.end_robust(); ++it ) { + uint64_t k = *(uint32_t*)it->GetKey(); + auto* v = it->value; + CHECK(v->v == 50); + } - delete key1; - delete key2; - delete key3; + delete key1; + delete key2; + delete key3; - delete val1; - delete val3; - delete val4; - } + delete val1; + delete val3; + delete val4; +} -TEST_CASE("dict iterator invalidation") - { - PDict dict; +TEST_CASE("dict iterator invalidation") { + PDict dict; - uint32_t val = 15; - uint32_t key_val = 5; - auto key = new detail::HashKey(key_val); + uint32_t val = 15; + uint32_t key_val = 5; + auto key = new detail::HashKey(key_val); - uint32_t val2 = 10; - uint32_t key_val2 = 25; - auto key2 = new detail::HashKey(key_val2); + uint32_t val2 = 10; + uint32_t key_val2 = 25; + auto key2 = new detail::HashKey(key_val2); - uint32_t val3 = 42; - uint32_t key_val3 = 37; - auto key3 = new detail::HashKey(key_val3); + uint32_t val3 = 42; + uint32_t key_val3 = 37; + auto key3 = new detail::HashKey(key_val3); - dict.Insert(key, &val); - dict.Insert(key2, &val2); + dict.Insert(key, &val); + dict.Insert(key2, &val2); - detail::HashKey* it_key; - bool iterators_invalidated = false; + detail::HashKey* it_key; + bool iterators_invalidated = false; - auto it = dict.begin(); - iterators_invalidated = false; - dict.Remove(key3, &iterators_invalidated); - // Key doesn't exist, nothing to remove, iteration not invalidated. - CHECK(! iterators_invalidated); + auto it = dict.begin(); + iterators_invalidated = false; + dict.Remove(key3, &iterators_invalidated); + // Key doesn't exist, nothing to remove, iteration not invalidated. + CHECK(! iterators_invalidated); - iterators_invalidated = false; - dict.Insert(key, &val2, &iterators_invalidated); - // Key exists, value gets overwritten, iteration not invalidated. - CHECK(! iterators_invalidated); + iterators_invalidated = false; + dict.Insert(key, &val2, &iterators_invalidated); + // Key exists, value gets overwritten, iteration not invalidated. + CHECK(! iterators_invalidated); - iterators_invalidated = false; - dict.Remove(key2, &iterators_invalidated); - // Key exists, gets removed, iteration is invalidated. - CHECK(iterators_invalidated); + iterators_invalidated = false; + dict.Remove(key2, &iterators_invalidated); + // Key exists, gets removed, iteration is invalidated. + CHECK(iterators_invalidated); - it = dict.begin(); - iterators_invalidated = false; - dict.Insert(key3, &val3, &iterators_invalidated); - // Key doesn't exist, gets inserted, iteration is invalidated. - CHECK(iterators_invalidated); + it = dict.begin(); + iterators_invalidated = false; + dict.Insert(key3, &val3, &iterators_invalidated); + // Key doesn't exist, gets inserted, iteration is invalidated. + CHECK(iterators_invalidated); - CHECK(dict.Length() == 2); - CHECK(*static_cast(dict.Lookup(key)) == val2); - CHECK(*static_cast(dict.Lookup(key3)) == val3); - CHECK(static_cast(dict.Lookup(key2)) == nullptr); + CHECK(dict.Length() == 2); + CHECK(*static_cast(dict.Lookup(key)) == val2); + CHECK(*static_cast(dict.Lookup(key3)) == val3); + CHECK(static_cast(dict.Lookup(key2)) == nullptr); - delete key; - delete key2; - delete key3; - } + delete key; + delete key2; + delete key3; +} // private -void generic_delete_func(void* v) - { - free(v); - } +void generic_delete_func(void* v) { free(v); } - } // namespace zeek +} // namespace zeek diff --git a/src/Dict.h b/src/Dict.h index e61faf5c66..a1162de2d2 100644 --- a/src/Dict.h +++ b/src/Dict.h @@ -24,22 +24,17 @@ using dict_delete_func = void (*)(void*); #define ASSERT_EQUAL(a, b) #endif // DEBUG -namespace zeek - { +namespace zeek { -template class Dictionary; +template +class Dictionary; -enum DictOrder - { - ORDERED, - UNORDERED - }; +enum DictOrder { ORDERED, UNORDERED }; // A dict_delete_func that just calls delete. extern void generic_delete_func(void*); -namespace detail - { +namespace detail { // Default number of hash buckets in dictionary. The dictionary will increase the size // of the hash table as needed. @@ -82,469 +77,418 @@ constexpr uint16_t TOO_FAR_TO_REACH = 0xFFFF; /** * An entry stored in the dictionary. */ -template class DictEntry - { +template +class DictEntry { public: #ifdef DEBUG - int bucket = 0; + int bucket = 0; #endif - // Distance from the expected position in the table. 0xFFFF means that the entry is empty. - uint16_t distance = TOO_FAR_TO_REACH; + // Distance from the expected position in the table. 0xFFFF means that the entry is empty. + uint16_t distance = TOO_FAR_TO_REACH; - // The size of the key. Less than 8 bytes we'll store directly in the entry, otherwise we'll - // store it as a pointer. This avoids extra allocations if we can help it. - uint32_t key_size = 0; + // The size of the key. Less than 8 bytes we'll store directly in the entry, otherwise we'll + // store it as a pointer. This avoids extra allocations if we can help it. + uint32_t key_size = 0; - // The maximum value of the key size above. This allows Dictionary to truncate keys before - // they get stored into an entry to avoid weird overflow errors. - static constexpr uint32_t MAX_KEY_SIZE = UINT32_MAX; + // The maximum value of the key size above. This allows Dictionary to truncate keys before + // they get stored into an entry to avoid weird overflow errors. + static constexpr uint32_t MAX_KEY_SIZE = UINT32_MAX; - // Lower 4 bytes of the 8-byte hash, which is used to calculate the position in the table. - uint32_t hash = 0; + // Lower 4 bytes of the 8-byte hash, which is used to calculate the position in the table. + uint32_t hash = 0; - T* value = nullptr; - union { - char key_here[8]; // hold key len<=8. when over 8, it's a pointer to real keys. - char* key; - }; + T* value = nullptr; + union { + char key_here[8]; // hold key len<=8. when over 8, it's a pointer to real keys. + char* key; + }; - DictEntry(void* arg_key, uint32_t key_size = 0, hash_t hash = 0, T* value = nullptr, - int16_t d = TOO_FAR_TO_REACH, bool copy_key = false) - : distance(d), key_size(key_size), hash((uint32_t)hash), value(value) - { - if ( ! arg_key ) - return; + DictEntry(void* arg_key, uint32_t key_size = 0, hash_t hash = 0, T* value = nullptr, int16_t d = TOO_FAR_TO_REACH, + bool copy_key = false) + : distance(d), key_size(key_size), hash((uint32_t)hash), value(value) { + if ( ! arg_key ) + return; - if ( key_size <= 8 ) - { - memcpy(key_here, arg_key, key_size); - if ( ! copy_key ) - delete[](char*) arg_key; // own the arg_key, now don't need it. - } - else - { - if ( copy_key ) - { - key = new char[key_size]; - memcpy(key, arg_key, key_size); - } - else - { - key = (char*)arg_key; - } - } - } + if ( key_size <= 8 ) { + memcpy(key_here, arg_key, key_size); + if ( ! copy_key ) + delete[](char*) arg_key; // own the arg_key, now don't need it. + } + else { + if ( copy_key ) { + key = new char[key_size]; + memcpy(key, arg_key, key_size); + } + else { + key = (char*)arg_key; + } + } + } - bool Empty() const { return distance == TOO_FAR_TO_REACH; } - void SetEmpty() - { - distance = TOO_FAR_TO_REACH; + bool Empty() const { return distance == TOO_FAR_TO_REACH; } + void SetEmpty() { + distance = TOO_FAR_TO_REACH; #ifdef DEBUG - hash = 0; - key = nullptr; - value = nullptr; - key_size = 0; - bucket = 0; + hash = 0; + key = nullptr; + value = nullptr; + key_size = 0; + bucket = 0; #endif // DEBUG - } + } - void Clear() - { - if ( key_size > 8 ) - delete[] key; - SetEmpty(); - } + void Clear() { + if ( key_size > 8 ) + delete[] key; + SetEmpty(); + } - const char* GetKey() const { return key_size <= 8 ? key_here : key; } - std::unique_ptr GetHashKey() const - { - return std::make_unique(GetKey(), key_size, hash); - } + const char* GetKey() const { return key_size <= 8 ? key_here : key; } + std::unique_ptr GetHashKey() const { + return std::make_unique(GetKey(), key_size, hash); + } - bool Equal(const char* arg_key, uint32_t arg_key_size, hash_t arg_hash) const - { // only 40-bit hash comparison. - return (0 == ((hash ^ arg_hash) & HASH_MASK)) && key_size == arg_key_size && - 0 == memcmp(GetKey(), arg_key, key_size); - } + bool Equal(const char* arg_key, uint32_t arg_key_size, hash_t arg_hash) const { // only 40-bit hash comparison. + return (0 == ((hash ^ arg_hash) & HASH_MASK)) && key_size == arg_key_size && + 0 == memcmp(GetKey(), arg_key, key_size); + } - bool operator==(const DictEntry& r) const { return Equal(r.GetKey(), r.key_size, r.hash); } - bool operator!=(const DictEntry& r) const { return ! Equal(r.GetKey(), r.key_size, r.hash); } - }; + bool operator==(const DictEntry& r) const { return Equal(r.GetKey(), r.key_size, r.hash); } + bool operator!=(const DictEntry& r) const { return ! Equal(r.GetKey(), r.key_size, r.hash); } +}; using DictEntryVec = std::vector; - } // namespace detail +} // namespace detail -template class DictIterator - { +template +class DictIterator { public: - using value_type = detail::DictEntry; - using reference = detail::DictEntry&; - using pointer = detail::DictEntry*; - using difference_type = std::ptrdiff_t; - using iterator_category = std::forward_iterator_tag; + using value_type = detail::DictEntry; + using reference = detail::DictEntry&; + using pointer = detail::DictEntry*; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; - DictIterator() = default; - ~DictIterator() - { - if ( dict ) - { - assert(dict->num_iterators > 0); - dict->DecrIters(); - } - } + DictIterator() = default; + ~DictIterator() { + if ( dict ) { + assert(dict->num_iterators > 0); + dict->DecrIters(); + } + } - DictIterator(const DictIterator& that) - { - if ( this == &that ) - return; + DictIterator(const DictIterator& that) { + if ( this == &that ) + return; - if ( dict ) - { - assert(dict->num_iterators > 0); - dict->DecrIters(); - } + if ( dict ) { + assert(dict->num_iterators > 0); + dict->DecrIters(); + } - dict = that.dict; - curr = that.curr; - end = that.end; - ordered_iter = that.ordered_iter; + dict = that.dict; + curr = that.curr; + end = that.end; + ordered_iter = that.ordered_iter; - dict->IncrIters(); - } + dict->IncrIters(); + } - DictIterator& operator=(const DictIterator& that) - { - if ( this == &that ) - return *this; + DictIterator& operator=(const DictIterator& that) { + if ( this == &that ) + return *this; - if ( dict ) - { - assert(dict->num_iterators > 0); - dict->DecrIters(); - } + if ( dict ) { + assert(dict->num_iterators > 0); + dict->DecrIters(); + } - dict = that.dict; - curr = that.curr; - end = that.end; - ordered_iter = that.ordered_iter; + dict = that.dict; + curr = that.curr; + end = that.end; + ordered_iter = that.ordered_iter; - dict->IncrIters(); + dict->IncrIters(); - return *this; - } + return *this; + } - DictIterator(DictIterator&& that) noexcept - { - if ( this == &that ) - return; + DictIterator(DictIterator&& that) noexcept { + if ( this == &that ) + return; - if ( dict ) - { - assert(dict->num_iterators > 0); - dict->DecrIters(); - } + if ( dict ) { + assert(dict->num_iterators > 0); + dict->DecrIters(); + } - dict = that.dict; - curr = that.curr; - end = that.end; - ordered_iter = that.ordered_iter; + dict = that.dict; + curr = that.curr; + end = that.end; + ordered_iter = that.ordered_iter; - that.dict = nullptr; - } + that.dict = nullptr; + } - DictIterator& operator=(DictIterator&& that) noexcept - { - if ( this == &that ) - return *this; + DictIterator& operator=(DictIterator&& that) noexcept { + if ( this == &that ) + return *this; - if ( dict ) - { - assert(dict->num_iterators > 0); - dict->DecrIters(); - } + if ( dict ) { + assert(dict->num_iterators > 0); + dict->DecrIters(); + } - dict = that.dict; - curr = that.curr; - end = that.end; - ordered_iter = that.ordered_iter; + dict = that.dict; + curr = that.curr; + end = that.end; + ordered_iter = that.ordered_iter; - that.dict = nullptr; + that.dict = nullptr; - return *this; - } + return *this; + } - reference operator*() - { - if ( dict->IsOrdered() ) - { - // TODO: how does this work if ordered_iter == end(). LookupEntry will return a nullptr, - // which the dereference will fail on. That's undefined behavior, correct? Is that any - // different than if the unordered version returns a dereference of it's end? - auto e = dict->LookupEntry(*ordered_iter); - return *e; - } + reference operator*() { + if ( dict->IsOrdered() ) { + // TODO: how does this work if ordered_iter == end(). LookupEntry will return a nullptr, + // which the dereference will fail on. That's undefined behavior, correct? Is that any + // different than if the unordered version returns a dereference of it's end? + auto e = dict->LookupEntry(*ordered_iter); + return *e; + } - return *curr; - } - reference operator*() const - { - if ( dict->IsOrdered() ) - { - auto e = dict->LookupEntry(*ordered_iter); - return *e; - } + return *curr; + } + reference operator*() const { + if ( dict->IsOrdered() ) { + auto e = dict->LookupEntry(*ordered_iter); + return *e; + } - return *curr; - } - pointer operator->() - { - if ( dict->IsOrdered() ) - return dict->LookupEntry(*ordered_iter); + return *curr; + } + pointer operator->() { + if ( dict->IsOrdered() ) + return dict->LookupEntry(*ordered_iter); - return curr; - } - pointer operator->() const - { - if ( dict->IsOrdered() ) - return dict->LookupEntry(*ordered_iter); + return curr; + } + pointer operator->() const { + if ( dict->IsOrdered() ) + return dict->LookupEntry(*ordered_iter); - return curr; - } + return curr; + } - DictIterator& operator++() - { - if ( dict->IsOrdered() ) - ++ordered_iter; - else - { - // The non-robust case is easy. Just advance the current position forward until you - // find one isn't empty and isn't the end. - do - { - ++curr; - } while ( curr != end && curr->Empty() ); - } + DictIterator& operator++() { + if ( dict->IsOrdered() ) + ++ordered_iter; + else { + // The non-robust case is easy. Just advance the current position forward until you + // find one isn't empty and isn't the end. + do { + ++curr; + } while ( curr != end && curr->Empty() ); + } - return *this; - } + return *this; + } - DictIterator operator++(int) - { - auto temp(*this); - ++*this; - return temp; - } + DictIterator operator++(int) { + auto temp(*this); + ++*this; + return temp; + } - bool operator==(const DictIterator& that) const - { - if ( dict != that.dict ) - return false; + bool operator==(const DictIterator& that) const { + if ( dict != that.dict ) + return false; - if ( dict->IsOrdered() ) - return ordered_iter == that.ordered_iter; + if ( dict->IsOrdered() ) + return ordered_iter == that.ordered_iter; - return curr == that.curr; - } + return curr == that.curr; + } - bool operator!=(const DictIterator& that) const { return ! (*this == that); } + bool operator!=(const DictIterator& that) const { return ! (*this == that); } private: - friend class Dictionary; + friend class Dictionary; - DictIterator(const Dictionary* d, detail::DictEntry* begin, detail::DictEntry* end) - : curr(begin), end(end) - { - // Cast away the constness so that the number of iterators can be modified in the - // dictionary. This does violate the constness guarantees of const-begin()/end() and - // cbegin()/cend(), but we're not modifying the actual data in the collection, just a - // counter in the wrapper of the collection. - dict = const_cast*>(d); + DictIterator(const Dictionary* d, detail::DictEntry* begin, detail::DictEntry* end) + : curr(begin), end(end) { + // Cast away the constness so that the number of iterators can be modified in the + // dictionary. This does violate the constness guarantees of const-begin()/end() and + // cbegin()/cend(), but we're not modifying the actual data in the collection, just a + // counter in the wrapper of the collection. + dict = const_cast*>(d); - // Make sure that we're starting on a non-empty element. - while ( curr != end && curr->Empty() ) - ++curr; + // Make sure that we're starting on a non-empty element. + while ( curr != end && curr->Empty() ) + ++curr; - dict->IncrIters(); - } + dict->IncrIters(); + } - DictIterator(const Dictionary* d, detail::DictEntryVec::iterator iter) : ordered_iter(iter) - { - // Cast away the constness so that the number of iterators can be modified in the - // dictionary. This does violate the constness guarantees of const-begin()/end() and - // cbegin()/cend(), but we're not modifying the actual data in the collection, just a - // counter in the wrapper of the collection. - dict = const_cast*>(d); - dict->IncrIters(); - } + DictIterator(const Dictionary* d, detail::DictEntryVec::iterator iter) : ordered_iter(iter) { + // Cast away the constness so that the number of iterators can be modified in the + // dictionary. This does violate the constness guarantees of const-begin()/end() and + // cbegin()/cend(), but we're not modifying the actual data in the collection, just a + // counter in the wrapper of the collection. + dict = const_cast*>(d); + dict->IncrIters(); + } - Dictionary* dict = nullptr; - detail::DictEntry* curr = nullptr; - detail::DictEntry* end = nullptr; - detail::DictEntryVec::iterator ordered_iter; - }; + Dictionary* dict = nullptr; + detail::DictEntry* curr = nullptr; + detail::DictEntry* end = nullptr; + detail::DictEntryVec::iterator ordered_iter; +}; -template class RobustDictIterator - { +template +class RobustDictIterator { public: - using value_type = detail::DictEntry; - using reference = detail::DictEntry&; - using pointer = detail::DictEntry*; - using difference_type = std::ptrdiff_t; - using iterator_category = std::forward_iterator_tag; + using value_type = detail::DictEntry; + using reference = detail::DictEntry&; + using pointer = detail::DictEntry*; + using difference_type = std::ptrdiff_t; + using iterator_category = std::forward_iterator_tag; - RobustDictIterator() : curr(nullptr) { } + RobustDictIterator() : curr(nullptr) {} - RobustDictIterator(Dictionary* d) : curr(nullptr), dict(d) - { - next = -1; - inserted = new std::vector>(); - visited = new std::vector>(); + RobustDictIterator(Dictionary* d) : curr(nullptr), dict(d) { + next = -1; + inserted = new std::vector>(); + visited = new std::vector>(); - dict->IncrIters(); - dict->iterators->push_back(this); + dict->IncrIters(); + dict->iterators->push_back(this); - // Advance the iterator one step so that we're at the first element. - curr = dict->GetNextRobustIteration(this); - } + // Advance the iterator one step so that we're at the first element. + curr = dict->GetNextRobustIteration(this); + } - RobustDictIterator(const RobustDictIterator& other) : curr(nullptr), dict(nullptr) - { - *this = other; - } + RobustDictIterator(const RobustDictIterator& other) : curr(nullptr), dict(nullptr) { *this = other; } - RobustDictIterator(RobustDictIterator&& other) noexcept : curr(nullptr), dict(nullptr) - { - *this = other; - } + RobustDictIterator(RobustDictIterator&& other) noexcept : curr(nullptr), dict(nullptr) { *this = other; } - ~RobustDictIterator() { Complete(); } + ~RobustDictIterator() { Complete(); } - reference operator*() { return curr; } - pointer operator->() { return &curr; } + reference operator*() { return curr; } + pointer operator->() { return &curr; } - RobustDictIterator& operator++() - { - curr = dict->GetNextRobustIteration(this); - return *this; - } + RobustDictIterator& operator++() { + curr = dict->GetNextRobustIteration(this); + return *this; + } - RobustDictIterator operator++(int) - { - auto temp(*this); - ++*this; - return temp; - } + RobustDictIterator operator++(int) { + auto temp(*this); + ++*this; + return temp; + } - RobustDictIterator& operator=(const RobustDictIterator& other) - { - if ( this == &other ) - return *this; + RobustDictIterator& operator=(const RobustDictIterator& other) { + if ( this == &other ) + return *this; - delete inserted; - inserted = nullptr; + delete inserted; + inserted = nullptr; - delete visited; - visited = nullptr; + delete visited; + visited = nullptr; - dict = nullptr; - curr.Clear(); - next = -1; + dict = nullptr; + curr.Clear(); + next = -1; - if ( other.dict ) - { - next = other.next; - inserted = new std::vector>(); - visited = new std::vector>(); + if ( other.dict ) { + next = other.next; + inserted = new std::vector>(); + visited = new std::vector>(); - if ( other.inserted ) - std::copy(other.inserted->begin(), other.inserted->end(), - std::back_inserter(*inserted)); + if ( other.inserted ) + std::copy(other.inserted->begin(), other.inserted->end(), std::back_inserter(*inserted)); - if ( other.visited ) - std::copy(other.visited->begin(), other.visited->end(), - std::back_inserter(*visited)); + if ( other.visited ) + std::copy(other.visited->begin(), other.visited->end(), std::back_inserter(*visited)); - dict = other.dict; - dict->IncrIters(); - dict->iterators->push_back(this); + dict = other.dict; + dict->IncrIters(); + dict->iterators->push_back(this); - curr = other.curr; - } + curr = other.curr; + } - return *this; - } + return *this; + } - RobustDictIterator& operator=(RobustDictIterator&& other) noexcept - { - delete inserted; - inserted = nullptr; + RobustDictIterator& operator=(RobustDictIterator&& other) noexcept { + delete inserted; + inserted = nullptr; - delete visited; - visited = nullptr; + delete visited; + visited = nullptr; - dict = nullptr; - curr.Clear(); - next = -1; + dict = nullptr; + curr.Clear(); + next = -1; - if ( other.dict ) - { - next = other.next; - inserted = other.inserted; - visited = other.visited; + if ( other.dict ) { + next = other.next; + inserted = other.inserted; + visited = other.visited; - dict = other.dict; - dict->iterators->push_back(this); - dict->iterators->erase( - std::remove(dict->iterators->begin(), dict->iterators->end(), &other), - dict->iterators->end()); - other.dict = nullptr; + dict = other.dict; + dict->iterators->push_back(this); + dict->iterators->erase(std::remove(dict->iterators->begin(), dict->iterators->end(), &other), + dict->iterators->end()); + other.dict = nullptr; - curr = std::move(other.curr); - } + curr = std::move(other.curr); + } - return *this; - } + return *this; + } - bool operator==(const RobustDictIterator& that) const { return curr == that.curr; } - bool operator!=(const RobustDictIterator& that) const { return ! (*this == that); } + bool operator==(const RobustDictIterator& that) const { return curr == that.curr; } + bool operator!=(const RobustDictIterator& that) const { return ! (*this == that); } private: - friend class Dictionary; + friend class Dictionary; - void Complete() - { - if ( dict ) - { - assert(dict->num_iterators > 0); - dict->DecrIters(); + void Complete() { + if ( dict ) { + assert(dict->num_iterators > 0); + dict->DecrIters(); - dict->iterators->erase( - std::remove(dict->iterators->begin(), dict->iterators->end(), this), - dict->iterators->end()); + dict->iterators->erase(std::remove(dict->iterators->begin(), dict->iterators->end(), this), + dict->iterators->end()); - delete inserted; - delete visited; + delete inserted; + delete visited; - inserted = nullptr; - visited = nullptr; - dict = nullptr; - } - } + inserted = nullptr; + visited = nullptr; + dict = nullptr; + } + } - // Tracks the new entries inserted while iterating. - std::vector>* inserted = nullptr; + // Tracks the new entries inserted while iterating. + std::vector>* inserted = nullptr; - // Tracks the entries already visited but were moved across the next iteration - // point due to an insertion. - std::vector>* visited = nullptr; + // Tracks the entries already visited but were moved across the next iteration + // point due to an insertion. + std::vector>* visited = nullptr; - detail::DictEntry curr; - Dictionary* dict = nullptr; - int next = -1; - }; + detail::DictEntry curr; + Dictionary* dict = nullptr; + int next = -1; +}; /** * A dictionary type that uses clustered hashing, a variation of Robinhood/Open Addressing @@ -558,1183 +502,1068 @@ private: * the keys but not the values. The dictionary size will be bounded at around 100K. 1M * entries is the absolute limit. Only Connections use that many entries, and that is rare. */ -template class Dictionary - { +template +class Dictionary { public: - explicit Dictionary(DictOrder ordering = UNORDERED, - int initial_size = detail::DEFAULT_DICT_SIZE) - { - if ( initial_size > 0 ) - { - // If an initial size is specified, init the table right away. Otherwise wait until the - // first insertion to init. - SetLog2Buckets(static_cast(std::log2(initial_size))); - Init(); - } + explicit Dictionary(DictOrder ordering = UNORDERED, int initial_size = detail::DEFAULT_DICT_SIZE) { + if ( initial_size > 0 ) { + // If an initial size is specified, init the table right away. Otherwise wait until the + // first insertion to init. + SetLog2Buckets(static_cast(std::log2(initial_size))); + Init(); + } - if ( ordering == ORDERED ) - order = std::make_unique>(); - } + if ( ordering == ORDERED ) + order = std::make_unique>(); + } - ~Dictionary() { Clear(); } + ~Dictionary() { Clear(); } - // Member functions for looking up a key, inserting/changing its - // contents, and deleting it. These come in two flavors: one - // which takes a zeek::detail::HashKey, and the other which takes a raw key, - // its size, and its (unmodulated) hash. - // lookup may move the key to right place if in the old zone to speed up the next lookup. - T* Lookup(const detail::HashKey* key) const - { - return Lookup(key->Key(), key->Size(), key->Hash()); - } + // Member functions for looking up a key, inserting/changing its + // contents, and deleting it. These come in two flavors: one + // which takes a zeek::detail::HashKey, and the other which takes a raw key, + // its size, and its (unmodulated) hash. + // lookup may move the key to right place if in the old zone to speed up the next lookup. + T* Lookup(const detail::HashKey* key) const { return Lookup(key->Key(), key->Size(), key->Hash()); } - T* Lookup(const void* key, int key_size, detail::hash_t h) const - { - if ( auto e = LookupEntry(key, key_size, h) ) - return e->value; + T* Lookup(const void* key, int key_size, detail::hash_t h) const { + if ( auto e = LookupEntry(key, key_size, h) ) + return e->value; - return nullptr; - } + return nullptr; + } - T* Lookup(const char* key) const - { - detail::HashKey h(key); - return Dictionary::Lookup(&h); - } + T* Lookup(const char* key) const { + detail::HashKey h(key); + return Dictionary::Lookup(&h); + } - // Returns previous value, or 0 if none. - // If iterators_invalidated is supplied, its value is set to true - // if the removal may have invalidated any existing iterators. - T* Insert(detail::HashKey* key, T* val, bool* iterators_invalidated = nullptr) - { - return Insert(key->TakeKey(), key->Size(), key->Hash(), val, false, iterators_invalidated); - } + // Returns previous value, or 0 if none. + // If iterators_invalidated is supplied, its value is set to true + // if the removal may have invalidated any existing iterators. + T* Insert(detail::HashKey* key, T* val, bool* iterators_invalidated = nullptr) { + return Insert(key->TakeKey(), key->Size(), key->Hash(), val, false, iterators_invalidated); + } - // If copy_key is true, then the key is copied, otherwise it's assumed - // that it's a heap pointer that now belongs to the Dictionary to - // manage as needed. - // If iterators_invalidated is supplied, its value is set to true - // if the removal may have invalidated any existing iterators. - T* Insert(void* key, uint64_t key_size, detail::hash_t hash, T* val, bool copy_key, - bool* iterators_invalidated = nullptr) - { - ASSERT_VALID(this); + // If copy_key is true, then the key is copied, otherwise it's assumed + // that it's a heap pointer that now belongs to the Dictionary to + // manage as needed. + // If iterators_invalidated is supplied, its value is set to true + // if the removal may have invalidated any existing iterators. + T* Insert(void* key, uint64_t key_size, detail::hash_t hash, T* val, bool copy_key, + bool* iterators_invalidated = nullptr) { + ASSERT_VALID(this); - // Initialize the table if it hasn't been done yet. This saves memory storing a bunch - // of empty dicts. - if ( ! table ) - Init(); + // Initialize the table if it hasn't been done yet. This saves memory storing a bunch + // of empty dicts. + if ( ! table ) + Init(); - T* v = nullptr; + T* v = nullptr; - if ( key_size > detail::DictEntry::MAX_KEY_SIZE ) - { - // If the key is bigger than something that will fit in a DictEntry, report a - // RuntimeError. This will throw an exception. If this call came from a script - // context, it'll cause the script interpreter to unwind and stop the script - // execution. If called elsewhere, Zeek will likely abort due to an unhandled - // exception. This is all entirely intentional. since if you got to this point - // something went really wrong with your input data. - auto loc = detail::GetCurrentLocation(); - reporter->RuntimeError(&loc, - "Attempted to create DictEntry with excessively large key, " - "truncating key (%" PRIu64 " > %u)", - key_size, detail::DictEntry::MAX_KEY_SIZE); - } + if ( key_size > detail::DictEntry::MAX_KEY_SIZE ) { + // If the key is bigger than something that will fit in a DictEntry, report a + // RuntimeError. This will throw an exception. If this call came from a script + // context, it'll cause the script interpreter to unwind and stop the script + // execution. If called elsewhere, Zeek will likely abort due to an unhandled + // exception. This is all entirely intentional. since if you got to this point + // something went really wrong with your input data. + auto loc = detail::GetCurrentLocation(); + reporter->RuntimeError(&loc, + "Attempted to create DictEntry with excessively large key, " + "truncating key (%" PRIu64 " > %u)", + key_size, detail::DictEntry::MAX_KEY_SIZE); + } - // Look to see if this key is already in the table. If found, insert_position is the - // position of the existing element. If not, insert_position is where it'll be inserted - // and insert_distance is the distance of the key for the position. - int insert_position = -1, insert_distance = -1; - int position = LookupIndex(key, key_size, hash, &insert_position, &insert_distance); - if ( position >= 0 ) - { - v = table[position].value; - table[position].value = val; - if ( ! copy_key ) - delete[](char*) key; + // Look to see if this key is already in the table. If found, insert_position is the + // position of the existing element. If not, insert_position is where it'll be inserted + // and insert_distance is the distance of the key for the position. + int insert_position = -1, insert_distance = -1; + int position = LookupIndex(key, key_size, hash, &insert_position, &insert_distance); + if ( position >= 0 ) { + v = table[position].value; + table[position].value = val; + if ( ! copy_key ) + delete[](char*) key; - if ( iterators && ! iterators->empty() ) - // need to set new v for iterators too. - for ( auto c : *iterators ) - { - // Check to see if this iterator points at the entry we're replacing. The - // iterator keeps a copy of the element, so we need to update it too. - if ( **c == table[position] ) - (*c)->value = val; + if ( iterators && ! iterators->empty() ) + // need to set new v for iterators too. + for ( auto c : *iterators ) { + // Check to see if this iterator points at the entry we're replacing. The + // iterator keeps a copy of the element, so we need to update it too. + if ( **c == table[position] ) + (*c)->value = val; - // Check if any of the inserted elements in this iterator point at the entry - // being replaced. Update those too. - auto it = std::find(c->inserted->begin(), c->inserted->end(), table[position]); - if ( it != c->inserted->end() ) - it->value = val; - } - } - else - { - if ( ! HaveOnlyRobustIterators() ) - { - if ( iterators_invalidated ) - *iterators_invalidated = true; - else - reporter->InternalWarning( - "Dictionary::Insert() possibly caused iterator invalidation"); - } + // Check if any of the inserted elements in this iterator point at the entry + // being replaced. Update those too. + auto it = std::find(c->inserted->begin(), c->inserted->end(), table[position]); + if ( it != c->inserted->end() ) + it->value = val; + } + } + else { + if ( ! HaveOnlyRobustIterators() ) { + if ( iterators_invalidated ) + *iterators_invalidated = true; + else + reporter->InternalWarning("Dictionary::Insert() possibly caused iterator invalidation"); + } - // Do this before the actual insertion since creating the DictEntry is going to delete - // the key data. We need a copy of it first. - if ( order ) - order->emplace_back(detail::HashKey{key, static_cast(key_size), hash}); + // Do this before the actual insertion since creating the DictEntry is going to delete + // the key data. We need a copy of it first. + if ( order ) + order->emplace_back(detail::HashKey{key, static_cast(key_size), hash}); - // Allocate memory for key if necessary. Key is updated to reflect internal key if - // necessary. - detail::DictEntry entry(key, key_size, hash, val, insert_distance, copy_key); - InsertRelocateAndAdjust(entry, insert_position); + // Allocate memory for key if necessary. Key is updated to reflect internal key if + // necessary. + detail::DictEntry entry(key, key_size, hash, val, insert_distance, copy_key); + InsertRelocateAndAdjust(entry, insert_position); - num_entries++; - cum_entries++; - if ( max_entries < num_entries ) - max_entries = num_entries; - if ( num_entries > ThresholdEntries() ) - SizeUp(); + num_entries++; + cum_entries++; + if ( max_entries < num_entries ) + max_entries = num_entries; + if ( num_entries > ThresholdEntries() ) + SizeUp(); - // if space_distance is too great, performance decreases. we need to sizeup for - // performance. - else if ( space_distance_samples > detail::MIN_SPACE_DISTANCE_SAMPLES && - static_cast(space_distance_sum) > - static_cast(space_distance_samples) * - detail::SPACE_DISTANCE_THRESHOLD && - static_cast(num_entries) > - detail::MIN_DICT_LOAD_FACTOR_100 * Capacity() / 100 ) - SizeUp(); - } + // if space_distance is too great, performance decreases. we need to sizeup for + // performance. + else if ( space_distance_samples > detail::MIN_SPACE_DISTANCE_SAMPLES && + static_cast(space_distance_sum) > + static_cast(space_distance_samples) * detail::SPACE_DISTANCE_THRESHOLD && + static_cast(num_entries) > detail::MIN_DICT_LOAD_FACTOR_100 * Capacity() / 100 ) + SizeUp(); + } - // Remap after insert can adjust asap to shorten period of mixed table. - // TODO: however, if remap happens right after size up, then it consumes more cpu for this - // cycle, a possible hiccup point. - if ( Remapping() ) - Remap(); - ASSERT_VALID(this); - return v; - } + // Remap after insert can adjust asap to shorten period of mixed table. + // TODO: however, if remap happens right after size up, then it consumes more cpu for this + // cycle, a possible hiccup point. + if ( Remapping() ) + Remap(); + ASSERT_VALID(this); + return v; + } - T* Insert(const char* key, T* val, bool* iterators_invalidated = nullptr) - { - detail::HashKey h(key); - return Insert(&h, val, iterators_invalidated); - } + T* Insert(const char* key, T* val, bool* iterators_invalidated = nullptr) { + detail::HashKey h(key); + return Insert(&h, val, iterators_invalidated); + } - // Removes the given element. Returns a pointer to the element in - // case it needs to be deleted. Returns 0 if no such element exists. - // If dontdelete is true, the key's bytes will not be deleted. - // If iterators_invalidated is supplied, its value is set to true - // if the removal may have invalidated any existing iterators. - T* Remove(const detail::HashKey* key, bool* iterators_invalidated = nullptr) - { - return Remove(key->Key(), key->Size(), key->Hash(), false, iterators_invalidated); - } - T* Remove(const void* key, int key_size, detail::hash_t hash, bool dont_delete = false, - bool* iterators_invalidated = nullptr) - { // cookie adjustment: maintain inserts here. maintain next in lower level version. - ASSERT_VALID(this); + // Removes the given element. Returns a pointer to the element in + // case it needs to be deleted. Returns 0 if no such element exists. + // If dontdelete is true, the key's bytes will not be deleted. + // If iterators_invalidated is supplied, its value is set to true + // if the removal may have invalidated any existing iterators. + T* Remove(const detail::HashKey* key, bool* iterators_invalidated = nullptr) { + return Remove(key->Key(), key->Size(), key->Hash(), false, iterators_invalidated); + } + T* Remove(const void* key, int key_size, detail::hash_t hash, bool dont_delete = false, + bool* iterators_invalidated = + nullptr) { // cookie adjustment: maintain inserts here. maintain next in lower level version. + ASSERT_VALID(this); - ASSERT(! dont_delete); // this is a poorly designed flag. if on, the internal has nowhere to - // return and memory is lost. + ASSERT(! dont_delete); // this is a poorly designed flag. if on, the internal has nowhere to + // return and memory is lost. - int position = LookupIndex(key, key_size, hash); - if ( position < 0 ) - return nullptr; + int position = LookupIndex(key, key_size, hash); + if ( position < 0 ) + return nullptr; - if ( ! HaveOnlyRobustIterators() ) - { - if ( iterators_invalidated ) - *iterators_invalidated = true; - else - reporter->InternalWarning( - "Dictionary::Remove() possibly caused iterator invalidation"); - } + if ( ! HaveOnlyRobustIterators() ) { + if ( iterators_invalidated ) + *iterators_invalidated = true; + else + reporter->InternalWarning("Dictionary::Remove() possibly caused iterator invalidation"); + } - detail::DictEntry entry = RemoveRelocateAndAdjust(position); - num_entries--; - ASSERT(num_entries >= 0); - // e is about to be invalid. remove it from all references. - if ( order ) - { - for ( auto it = order->begin(); it != order->end(); ++it ) - { - if ( it->Equal(key, key_size, hash) ) - { - it = order->erase(it); - break; - } - } - } + detail::DictEntry entry = RemoveRelocateAndAdjust(position); + num_entries--; + ASSERT(num_entries >= 0); + // e is about to be invalid. remove it from all references. + if ( order ) { + for ( auto it = order->begin(); it != order->end(); ++it ) { + if ( it->Equal(key, key_size, hash) ) { + it = order->erase(it); + break; + } + } + } - T* v = entry.value; - entry.Clear(); - ASSERT_VALID(this); - return v; - } + T* v = entry.value; + entry.Clear(); + ASSERT_VALID(this); + return v; + } - // TODO: these came from PDict. They could probably be deprecated and removed in favor of - // just using Remove(). - T* RemoveEntry(const detail::HashKey* key, bool* iterators_invalidated = nullptr) - { - return Remove(key->Key(), key->Size(), key->Hash(), false, iterators_invalidated); - } - T* RemoveEntry(const detail::HashKey& key, bool* iterators_invalidated = nullptr) - { - return Remove(key.Key(), key.Size(), key.Hash(), false, iterators_invalidated); - } + // TODO: these came from PDict. They could probably be deprecated and removed in favor of + // just using Remove(). + T* RemoveEntry(const detail::HashKey* key, bool* iterators_invalidated = nullptr) { + return Remove(key->Key(), key->Size(), key->Hash(), false, iterators_invalidated); + } + T* RemoveEntry(const detail::HashKey& key, bool* iterators_invalidated = nullptr) { + return Remove(key.Key(), key.Size(), key.Hash(), false, iterators_invalidated); + } - // Number of entries. - int Length() const { return num_entries; } + // Number of entries. + int Length() const { return num_entries; } - // Largest it's ever been. - int MaxLength() const { return max_entries; } + // Largest it's ever been. + int MaxLength() const { return max_entries; } - // Total number of entries ever. - uint64_t NumCumulativeInserts() const { return cum_entries; } + // Total number of entries ever. + uint64_t NumCumulativeInserts() const { return cum_entries; } - // True if the dictionary is ordered, false otherwise. - int IsOrdered() const { return order != nullptr; } + // True if the dictionary is ordered, false otherwise. + int IsOrdered() const { return order != nullptr; } - // If the dictionary is ordered then returns the n'th entry's value; - // the second method also returns the key. The first entry inserted - // corresponds to n=0. - // - // Returns nil if the dictionary is not ordered or if "n" is out - // of range. - T* NthEntry(int n) const - { - const void* key; - int key_len; - return NthEntry(n, key, key_len); - } + // If the dictionary is ordered then returns the n'th entry's value; + // the second method also returns the key. The first entry inserted + // corresponds to n=0. + // + // Returns nil if the dictionary is not ordered or if "n" is out + // of range. + T* NthEntry(int n) const { + const void* key; + int key_len; + return NthEntry(n, key, key_len); + } - T* NthEntry(int n, const void*& key, int& key_size) const - { - if ( ! order || n < 0 || n >= Length() ) - return nullptr; + T* NthEntry(int n, const void*& key, int& key_size) const { + if ( ! order || n < 0 || n >= Length() ) + return nullptr; - auto& hk = order->at(n); - auto entry = Lookup(&hk); + auto& hk = order->at(n); + auto entry = Lookup(&hk); - key = hk.Key(); - key_size = hk.Size(); - return entry; - } + key = hk.Key(); + key_size = hk.Size(); + return entry; + } - T* NthEntry(int n, const char*& key) const - { - int key_len; - return NthEntry(n, (const void*&)key, key_len); - } + T* NthEntry(int n, const char*& key) const { + int key_len; + return NthEntry(n, (const void*&)key, key_len); + } - void SetDeleteFunc(dict_delete_func f) { delete_func = f; } + void SetDeleteFunc(dict_delete_func f) { delete_func = f; } - // Remove all entries. - void Clear() - { - if ( table ) - { - for ( int i = Capacity() - 1; i >= 0; i-- ) - { - if ( table[i].Empty() ) - continue; - if ( delete_func ) - delete_func(table[i].value); - table[i].Clear(); - } - free(table); - table = nullptr; - } + // Remove all entries. + void Clear() { + if ( table ) { + for ( int i = Capacity() - 1; i >= 0; i-- ) { + if ( table[i].Empty() ) + continue; + if ( delete_func ) + delete_func(table[i].value); + table[i].Clear(); + } + free(table); + table = nullptr; + } - if ( order ) - order.reset(); + if ( order ) + order.reset(); - if ( iterators ) - { - delete iterators; - iterators = nullptr; - } - log2_buckets = 0; - num_iterators = 0; - remaps = 0; - remap_end = -1; - num_entries = 0; - max_entries = 0; - } + if ( iterators ) { + delete iterators; + iterators = nullptr; + } + log2_buckets = 0; + num_iterators = 0; + remaps = 0; + remap_end = -1; + num_entries = 0; + max_entries = 0; + } - /// The capacity of the table, Buckets + Overflow Size. - int Capacity() const { return table ? bucket_capacity : 0; } - int ExpectedCapacity() const { return bucket_capacity; } + /// The capacity of the table, Buckets + Overflow Size. + int Capacity() const { return table ? bucket_capacity : 0; } + int ExpectedCapacity() const { return bucket_capacity; } - // Debugging -#define DUMPIF(f) \ - if ( f ) \ - Dump(1) + // Debugging +#define DUMPIF(f) \ + if ( f ) \ + Dump(1) #ifdef ZEEK_DICT_DEBUG - void AssertValid() const - { - bool valid = true; - int n = num_entries; + void AssertValid() const { + bool valid = true; + int n = num_entries; - if ( table ) - for ( int i = Capacity() - 1; i >= 0; i-- ) - if ( ! table[i].Empty() ) - n--; + if ( table ) + for ( int i = Capacity() - 1; i >= 0; i-- ) + if ( ! table[i].Empty() ) + n--; - valid = (n == 0); - ASSERT(valid); - DUMPIF(! valid); + valid = (n == 0); + ASSERT(valid); + DUMPIF(! valid); - // entries must clustered together - for ( int i = 1; i < Capacity(); i++ ) - { - if ( ! table || table[i].Empty() ) - continue; + // entries must clustered together + for ( int i = 1; i < Capacity(); i++ ) { + if ( ! table || table[i].Empty() ) + continue; - if ( table[i - 1].Empty() ) - { - valid = (table[i].distance == 0); - ASSERT(valid); - DUMPIF(! valid); - } - else - { - valid = (table[i].bucket >= table[i - 1].bucket); - ASSERT(valid); - DUMPIF(! valid); + if ( table[i - 1].Empty() ) { + valid = (table[i].distance == 0); + ASSERT(valid); + DUMPIF(! valid); + } + else { + valid = (table[i].bucket >= table[i - 1].bucket); + ASSERT(valid); + DUMPIF(! valid); - if ( table[i].bucket == table[i - 1].bucket ) - { - valid = (table[i].distance == table[i - 1].distance + 1); - ASSERT(valid); - DUMPIF(! valid); - } - else - { - valid = (table[i].distance <= table[i - 1].distance); - ASSERT(valid); - DUMPIF(! valid); - } - } - } - } + if ( table[i].bucket == table[i - 1].bucket ) { + valid = (table[i].distance == table[i - 1].distance + 1); + ASSERT(valid); + DUMPIF(! valid); + } + else { + valid = (table[i].distance <= table[i - 1].distance); + ASSERT(valid); + DUMPIF(! valid); + } + } + } + } #endif // ZEEK_DICT_DEBUG - void Dump(int level = 0) const - { - int key_size = 0; - for ( int i = 0; i < Capacity(); i++ ) - { - if ( table[i].Empty() ) - continue; - key_size += zeek::util::pad_size(table[i].key_size); - if ( ! table[i].value ) - continue; - } + void Dump(int level = 0) const { + int key_size = 0; + for ( int i = 0; i < Capacity(); i++ ) { + if ( table[i].Empty() ) + continue; + key_size += zeek::util::pad_size(table[i].key_size); + if ( ! table[i].value ) + continue; + } #define DICT_NUM_DISTANCES 5 - int distances[DICT_NUM_DISTANCES]; - int max_distance = 0; - DistanceStats(max_distance, distances, DICT_NUM_DISTANCES); - printf("cap %'7d ent %'7d %'-7d load %.2f max_dist %2d key/ent %3d lg " - "%2d remaps %1d remap_end %4d ", - Capacity(), Length(), MaxLength(), (double)Length() / (table ? Capacity() : 1), - max_distance, key_size / (Length() ? Length() : 1), log2_buckets, remaps, remap_end); - if ( Length() > 0 ) - { - for ( int i = 0; i < DICT_NUM_DISTANCES - 1; i++ ) - printf("[%d]%2d%% ", i, 100 * distances[i] / Length()); - printf("[%d+]%2d%% ", DICT_NUM_DISTANCES - 1, - 100 * distances[DICT_NUM_DISTANCES - 1] / Length()); - } - else - printf("\n"); + int distances[DICT_NUM_DISTANCES]; + int max_distance = 0; + DistanceStats(max_distance, distances, DICT_NUM_DISTANCES); + printf( + "cap %'7d ent %'7d %'-7d load %.2f max_dist %2d key/ent %3d lg " + "%2d remaps %1d remap_end %4d ", + Capacity(), Length(), MaxLength(), (double)Length() / (table ? Capacity() : 1), max_distance, + key_size / (Length() ? Length() : 1), log2_buckets, remaps, remap_end); + if ( Length() > 0 ) { + for ( int i = 0; i < DICT_NUM_DISTANCES - 1; i++ ) + printf("[%d]%2d%% ", i, 100 * distances[i] / Length()); + printf("[%d+]%2d%% ", DICT_NUM_DISTANCES - 1, 100 * distances[DICT_NUM_DISTANCES - 1] / Length()); + } + else + printf("\n"); - printf("\n"); - if ( level >= 1 ) - { - printf("%-10s %1s %-10s %-4s %-4s %-10s %-18s %-2s\n", "Index", "*", "Bucket", "Dist", - "Off", "Hash", "FibHash", "KeySize"); - for ( int i = 0; i < Capacity(); i++ ) - if ( table[i].Empty() ) - printf("%'10d \n", i); - else - printf("%'10d %1s %'10d %4d %4d 0x%08x 0x%016" PRIx64 "(%3d) %2d\n", i, - (i <= remap_end ? "*" : ""), BucketByPosition(i), (int)table[i].distance, - OffsetInClusterByPosition(i), uint(table[i].hash), - FibHash(table[i].hash), (int)FibHash(table[i].hash) & 0xFF, - (int)table[i].key_size); - } - } + printf("\n"); + if ( level >= 1 ) { + printf("%-10s %1s %-10s %-4s %-4s %-10s %-18s %-2s\n", "Index", "*", "Bucket", "Dist", "Off", "Hash", + "FibHash", "KeySize"); + for ( int i = 0; i < Capacity(); i++ ) + if ( table[i].Empty() ) + printf("%'10d \n", i); + else + printf("%'10d %1s %'10d %4d %4d 0x%08x 0x%016" PRIx64 "(%3d) %2d\n", i, (i <= remap_end ? "*" : ""), + BucketByPosition(i), (int)table[i].distance, OffsetInClusterByPosition(i), + uint(table[i].hash), FibHash(table[i].hash), (int)FibHash(table[i].hash) & 0xFF, + (int)table[i].key_size); + } + } - void DistanceStats(int& max_distance, int* distances = 0, int num_distances = 0) const - { - max_distance = 0; - for ( int i = 0; i < num_distances; i++ ) - distances[i] = 0; + void DistanceStats(int& max_distance, int* distances = 0, int num_distances = 0) const { + max_distance = 0; + for ( int i = 0; i < num_distances; i++ ) + distances[i] = 0; - for ( int i = 0; i < Capacity(); i++ ) - { - if ( table[i].Empty() ) - continue; - if ( table[i].distance > max_distance ) - max_distance = table[i].distance; - if ( num_distances <= 0 || ! distances ) - continue; - if ( table[i].distance >= num_distances - 1 ) - distances[num_distances - 1]++; - else - distances[table[i].distance]++; - } - } + for ( int i = 0; i < Capacity(); i++ ) { + if ( table[i].Empty() ) + continue; + if ( table[i].distance > max_distance ) + max_distance = table[i].distance; + if ( num_distances <= 0 || ! distances ) + continue; + if ( table[i].distance >= num_distances - 1 ) + distances[num_distances - 1]++; + else + distances[table[i].distance]++; + } + } - void DumpKeys() const - { - if ( ! table ) - return; + void DumpKeys() const { + if ( ! table ) + return; - char key_file[100]; - // Detect string or binary from first key. - int i = 0; - while ( table[i].Empty() && i < Capacity() ) - i++; + char key_file[100]; + // Detect string or binary from first key. + int i = 0; + while ( table[i].Empty() && i < Capacity() ) + i++; - bool binary = false; - const char* key = table[i].GetKey(); - for ( int j = 0; j < table[i].key_size; j++ ) - if ( ! isprint(key[j]) ) - { - binary = true; - break; - } - int max_distance = 0; + bool binary = false; + const char* key = table[i].GetKey(); + for ( int j = 0; j < table[i].key_size; j++ ) + if ( ! isprint(key[j]) ) { + binary = true; + break; + } + int max_distance = 0; - DistanceStats(max_distance); - if ( binary ) - { - char key = char(random() % 26) + 'A'; - snprintf(key_file, 100, "%d.%d-%c.key", Length(), max_distance, key); - std::ofstream f(key_file, std::ios::binary | std::ios::out | std::ios::trunc); - for ( int idx = 0; idx < Capacity(); idx++ ) - if ( ! table[idx].Empty() ) - { - int key_size = table[idx].key_size; - f.write((const char*)&key_size, sizeof(int)); - f.write(table[idx].GetKey(), table[idx].key_size); - } - } - else - { - char key = char(random() % 26) + 'A'; - snprintf(key_file, 100, "%d.%d-%d.ckey", Length(), max_distance, key); - std::ofstream f(key_file, std::ios::out | std::ios::trunc); - for ( int idx = 0; idx < Capacity(); idx++ ) - if ( ! table[idx].Empty() ) - { - std::string s((char*)table[idx].GetKey(), table[idx].key_size); - f << s << std::endl; - } - } - } + DistanceStats(max_distance); + if ( binary ) { + char key = char(random() % 26) + 'A'; + snprintf(key_file, 100, "%d.%d-%c.key", Length(), max_distance, key); + std::ofstream f(key_file, std::ios::binary | std::ios::out | std::ios::trunc); + for ( int idx = 0; idx < Capacity(); idx++ ) + if ( ! table[idx].Empty() ) { + int key_size = table[idx].key_size; + f.write((const char*)&key_size, sizeof(int)); + f.write(table[idx].GetKey(), table[idx].key_size); + } + } + else { + char key = char(random() % 26) + 'A'; + snprintf(key_file, 100, "%d.%d-%d.ckey", Length(), max_distance, key); + std::ofstream f(key_file, std::ios::out | std::ios::trunc); + for ( int idx = 0; idx < Capacity(); idx++ ) + if ( ! table[idx].Empty() ) { + std::string s((char*)table[idx].GetKey(), table[idx].key_size); + f << s << std::endl; + } + } + } - // Type traits needed for some of the std algorithms to work - using value_type = detail::DictEntry; - using pointer = detail::DictEntry*; - using const_pointer = const detail::DictEntry*; + // Type traits needed for some of the std algorithms to work + using value_type = detail::DictEntry; + using pointer = detail::DictEntry*; + using const_pointer = const detail::DictEntry*; - // Iterator support - using iterator = DictIterator; - using const_iterator = const iterator; - using reverse_iterator = std::reverse_iterator; - using const_reverse_iterator = std::reverse_iterator; + // Iterator support + using iterator = DictIterator; + using const_iterator = const iterator; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; - iterator begin() - { - if ( IsOrdered() ) - return {this, order->begin()}; + iterator begin() { + if ( IsOrdered() ) + return {this, order->begin()}; - return {this, table, table + Capacity()}; - } - iterator end() - { - if ( IsOrdered() ) - return {this, order->end()}; + return {this, table, table + Capacity()}; + } + iterator end() { + if ( IsOrdered() ) + return {this, order->end()}; - return {this, table + Capacity(), table + Capacity()}; - } - const_iterator begin() const - { - if ( IsOrdered() ) - return {this, order->begin()}; + return {this, table + Capacity(), table + Capacity()}; + } + const_iterator begin() const { + if ( IsOrdered() ) + return {this, order->begin()}; - return {this, table, table + Capacity()}; - } - const_iterator end() const - { - if ( IsOrdered() ) - return {this, order->end()}; + return {this, table, table + Capacity()}; + } + const_iterator end() const { + if ( IsOrdered() ) + return {this, order->end()}; - return {this, table + Capacity(), table + Capacity()}; - } - const_iterator cbegin() - { - if ( IsOrdered() ) - return {this, order->begin()}; + return {this, table + Capacity(), table + Capacity()}; + } + const_iterator cbegin() { + if ( IsOrdered() ) + return {this, order->begin()}; - return {this, table, table + Capacity()}; - } - const_iterator cend() - { - if ( IsOrdered() ) - return {this, order->end()}; + return {this, table, table + Capacity()}; + } + const_iterator cend() { + if ( IsOrdered() ) + return {this, order->end()}; - return {this, table + Capacity(), table + Capacity()}; - } + return {this, table + Capacity(), table + Capacity()}; + } - RobustDictIterator begin_robust() { return MakeRobustIterator(); } - RobustDictIterator end_robust() { return RobustDictIterator(); } + RobustDictIterator begin_robust() { return MakeRobustIterator(); } + RobustDictIterator end_robust() { return RobustDictIterator(); } private: - friend zeek::DictIterator; - friend zeek::RobustDictIterator; + friend zeek::DictIterator; + friend zeek::RobustDictIterator; - void SetLog2Buckets(int value) - { - log2_buckets = value; - bucket_count = 1 << log2_buckets; - bucket_capacity = (1 << log2_buckets) + log2_buckets; - } + void SetLog2Buckets(int value) { + log2_buckets = value; + bucket_count = 1 << log2_buckets; + bucket_capacity = (1 << log2_buckets) + log2_buckets; + } - /// Buckets of the table, not including overflow size. - int Buckets() const { return table ? bucket_count : 0; } + /// Buckets of the table, not including overflow size. + int Buckets() const { return table ? bucket_count : 0; } - // bucket math - uint32_t ThresholdEntries() const - { - // Increase the size of the dictionary when it is 75% full. However, when the dictionary - // is small ( bucket_capacity <= 2^3+3=11 elements ), only resize it when it's 100% full. - // The dictionary will always resize when the current insertion causes it to be full. This - // ensures that the current insertion should always be successful. - int capacity = Capacity(); - if ( log2_buckets <= detail::DICT_THRESHOLD_BITS ) - return capacity; - return capacity * detail::DICT_LOAD_FACTOR_100 / 100; - } + // bucket math + uint32_t ThresholdEntries() const { + // Increase the size of the dictionary when it is 75% full. However, when the dictionary + // is small ( bucket_capacity <= 2^3+3=11 elements ), only resize it when it's 100% full. + // The dictionary will always resize when the current insertion causes it to be full. This + // ensures that the current insertion should always be successful. + int capacity = Capacity(); + if ( log2_buckets <= detail::DICT_THRESHOLD_BITS ) + return capacity; + return capacity * detail::DICT_LOAD_FACTOR_100 / 100; + } - // Used to improve the distribution of the original hash. - detail::hash_t FibHash(detail::hash_t h) const - { - // GoldenRatio phi = (sqrt(5)+1)/2 = 1.6180339887... - // 1/phi = phi - 1 - h &= detail::HASH_MASK; - h *= 11400714819323198485llu; // 2^64/phi - return h; - } + // Used to improve the distribution of the original hash. + detail::hash_t FibHash(detail::hash_t h) const { + // GoldenRatio phi = (sqrt(5)+1)/2 = 1.6180339887... + // 1/phi = phi - 1 + h &= detail::HASH_MASK; + h *= 11400714819323198485llu; // 2^64/phi + return h; + } - // Maps a hash to the appropriate n-bit table bucket. - int BucketByHash(detail::hash_t h, int bit) const - { - ASSERT(bit >= 0); - if ( ! bit ) - return 0; //<< >> breaks on 64. + // Maps a hash to the appropriate n-bit table bucket. + int BucketByHash(detail::hash_t h, int bit) const { + ASSERT(bit >= 0); + if ( ! bit ) + return 0; //<< >> breaks on 64. #ifdef DICT_NO_FIB_HASH - detail::hash_t hash = h; + detail::hash_t hash = h; #else - detail::hash_t hash = FibHash(h); + detail::hash_t hash = FibHash(h); #endif - int m = 64 - bit; - hash <<= m; - hash >>= m; + int m = 64 - bit; + hash <<= m; + hash >>= m; - return hash; - } + return hash; + } - // Given a position of a non-empty item in the table, find the related bucket. - int BucketByPosition(int position) const - { - ASSERT(table && position >= 0 && position < Capacity() && ! table[position].Empty()); - return position - table[position].distance; - } + // Given a position of a non-empty item in the table, find the related bucket. + int BucketByPosition(int position) const { + ASSERT(table && position >= 0 && position < Capacity() && ! table[position].Empty()); + return position - table[position].distance; + } - // Given a bucket of a non-empty item in the table, find the end of its cluster. - // The end should be equal to tail+1 if tail exists. Otherwise it's the tail of - // the just-smaller cluster + 1. - int EndOfClusterByBucket(int bucket) const - { - ASSERT(bucket >= 0 && bucket < Buckets()); - int i = bucket; - int current_cap = Capacity(); - while ( i < current_cap && ! table[i].Empty() && BucketByPosition(i) <= bucket ) - i++; - return i; - } + // Given a bucket of a non-empty item in the table, find the end of its cluster. + // The end should be equal to tail+1 if tail exists. Otherwise it's the tail of + // the just-smaller cluster + 1. + int EndOfClusterByBucket(int bucket) const { + ASSERT(bucket >= 0 && bucket < Buckets()); + int i = bucket; + int current_cap = Capacity(); + while ( i < current_cap && ! table[i].Empty() && BucketByPosition(i) <= bucket ) + i++; + return i; + } - // Given a position of a non-empty item in the table, find the head of its cluster. - int HeadOfClusterByPosition(int position) const - { - // Finding the first entry in the bucket chain. - ASSERT(0 <= position && position < Capacity() && ! table[position].Empty()); + // Given a position of a non-empty item in the table, find the head of its cluster. + int HeadOfClusterByPosition(int position) const { + // Finding the first entry in the bucket chain. + ASSERT(0 <= position && position < Capacity() && ! table[position].Empty()); - // Look backward for the first item with the same bucket as myself. - int bucket = BucketByPosition(position); - int i = position; - while ( i >= bucket && BucketByPosition(i) == bucket ) - i--; + // Look backward for the first item with the same bucket as myself. + int bucket = BucketByPosition(position); + int i = position; + while ( i >= bucket && BucketByPosition(i) == bucket ) + i--; - return i == bucket ? i : i + 1; - } + return i == bucket ? i : i + 1; + } - // Given a position of a non-empty item in the table, find the tail of its cluster. - int TailOfClusterByPosition(int position) const - { - ASSERT(0 <= position && position < Capacity() && ! table[position].Empty()); + // Given a position of a non-empty item in the table, find the tail of its cluster. + int TailOfClusterByPosition(int position) const { + ASSERT(0 <= position && position < Capacity() && ! table[position].Empty()); - int bucket = BucketByPosition(position); - int i = position; - int current_cap = Capacity(); - while ( i < current_cap && ! table[i].Empty() && BucketByPosition(i) == bucket ) - i++; // stop just over the tail. + int bucket = BucketByPosition(position); + int i = position; + int current_cap = Capacity(); + while ( i < current_cap && ! table[i].Empty() && BucketByPosition(i) == bucket ) + i++; // stop just over the tail. - return i - 1; - } + return i - 1; + } - // Given a position of a non-empty item in the table, find the end of its cluster. - // The end should be equal to tail+1 if tail exists. Otherwise it's the tail of - // the just-smaller cluster + 1. - int EndOfClusterByPosition(int position) const { return TailOfClusterByPosition(position) + 1; } + // Given a position of a non-empty item in the table, find the end of its cluster. + // The end should be equal to tail+1 if tail exists. Otherwise it's the tail of + // the just-smaller cluster + 1. + int EndOfClusterByPosition(int position) const { return TailOfClusterByPosition(position) + 1; } - // Given a position of a non-empty item in the table, find the offset of it within - // its cluster. - int OffsetInClusterByPosition(int position) const - { - ASSERT(0 <= position && position < Capacity() && ! table[position].Empty()); - int head = HeadOfClusterByPosition(position); - return position - head; - } + // Given a position of a non-empty item in the table, find the offset of it within + // its cluster. + int OffsetInClusterByPosition(int position) const { + ASSERT(0 <= position && position < Capacity() && ! table[position].Empty()); + int head = HeadOfClusterByPosition(position); + return position - head; + } - // Next non-empty item position in the table, starting at the specified position. - int Next(int position) const - { - ASSERT(table && -1 <= position && position < Capacity()); + // Next non-empty item position in the table, starting at the specified position. + int Next(int position) const { + ASSERT(table && -1 <= position && position < Capacity()); - int current_cap = Capacity(); - do - { - position++; - } while ( position < current_cap && table[position].Empty() ); + int current_cap = Capacity(); + do { + position++; + } while ( position < current_cap && table[position].Empty() ); - return position; - } + return position; + } - void Init() - { - ASSERT(! table); - table = (detail::DictEntry*)malloc(sizeof(detail::DictEntry) * ExpectedCapacity()); - for ( int i = Capacity() - 1; i >= 0; i-- ) - table[i].SetEmpty(); - } + void Init() { + ASSERT(! table); + table = (detail::DictEntry*)malloc(sizeof(detail::DictEntry) * ExpectedCapacity()); + for ( int i = Capacity() - 1; i >= 0; i-- ) + table[i].SetEmpty(); + } - // Lookup - int LinearLookupIndex(const void* key, int key_size, detail::hash_t hash) const - { - auto current_cap = Capacity(); - for ( int i = 0; i < current_cap; i++ ) - if ( ! table[i].Empty() && table[i].Equal((const char*)key, key_size, hash) ) - return i; - return -1; - } + // Lookup + int LinearLookupIndex(const void* key, int key_size, detail::hash_t hash) const { + auto current_cap = Capacity(); + for ( int i = 0; i < current_cap; i++ ) + if ( ! table[i].Empty() && table[i].Equal((const char*)key, key_size, hash) ) + return i; + return -1; + } - // Lookup position for all possible table_sizes caused by remapping. Remap it immediately - // if not in the middle of iteration. - int LookupIndex(const void* key, int key_size, detail::hash_t hash, - int* insert_position = nullptr, int* insert_distance = nullptr) - { - ASSERT_VALID(this); - if ( ! table ) - return -1; + // Lookup position for all possible table_sizes caused by remapping. Remap it immediately + // if not in the middle of iteration. + int LookupIndex(const void* key, int key_size, detail::hash_t hash, int* insert_position = nullptr, + int* insert_distance = nullptr) { + ASSERT_VALID(this); + if ( ! table ) + return -1; - int bucket = BucketByHash(hash, log2_buckets); + int bucket = BucketByHash(hash, log2_buckets); #ifdef ZEEK_DICT_DEBUG - int linear_position = LinearLookupIndex(key, key_size, hash); + int linear_position = LinearLookupIndex(key, key_size, hash); #endif // ZEEK_DICT_DEBUG - int position = LookupIndex(key, key_size, hash, bucket, Capacity(), insert_position, - insert_distance); - if ( position >= 0 ) - { - ASSERT_EQUAL(position, linear_position); // same as linearLookup - return position; - } + int position = LookupIndex(key, key_size, hash, bucket, Capacity(), insert_position, insert_distance); + if ( position >= 0 ) { + ASSERT_EQUAL(position, linear_position); // same as linearLookup + return position; + } - for ( int i = 1; i <= remaps; i++ ) - { - int prev_bucket = BucketByHash(hash, log2_buckets - i); - if ( prev_bucket <= remap_end ) - { - // possibly here. insert_position & insert_distance returned on failed lookup is - // not valid in previous table_sizes. - position = LookupIndex(key, key_size, hash, prev_bucket, remap_end + 1); - if ( position >= 0 ) - { - ASSERT_EQUAL(position, linear_position); // same as linearLookup - // remap immediately if no iteration is on. - if ( ! num_iterators ) - { - Remap(position, &position); - ASSERT_EQUAL(position, LookupIndex(key, key_size, hash)); - } - return position; - } - } - } - // not found + for ( int i = 1; i <= remaps; i++ ) { + int prev_bucket = BucketByHash(hash, log2_buckets - i); + if ( prev_bucket <= remap_end ) { + // possibly here. insert_position & insert_distance returned on failed lookup is + // not valid in previous table_sizes. + position = LookupIndex(key, key_size, hash, prev_bucket, remap_end + 1); + if ( position >= 0 ) { + ASSERT_EQUAL(position, linear_position); // same as linearLookup + // remap immediately if no iteration is on. + if ( ! num_iterators ) { + Remap(position, &position); + ASSERT_EQUAL(position, LookupIndex(key, key_size, hash)); + } + return position; + } + } + } + // not found #ifdef ZEEK_DICT_DEBUG - if ( linear_position >= 0 ) - { // different. stop and try to see whats happening. - ASSERT(false); - // rerun the function in debugger to track down the bug. - LookupIndex(key, key_size, hash); - } + if ( linear_position >= 0 ) { // different. stop and try to see whats happening. + ASSERT(false); + // rerun the function in debugger to track down the bug. + LookupIndex(key, key_size, hash); + } #endif // ZEEK_DICT_DEBUG - return -1; - } + return -1; + } - // Returns the position of the item if it exists. Otherwise returns -1, but set the insert - // position/distance if required. The starting point for the search may not be the bucket - // for the current table size since this method is also used to search for an item in the - // previous table size. - int LookupIndex(const void* key, int key_size, detail::hash_t hash, int begin, int end, - int* insert_position = nullptr, int* insert_distance = nullptr) - { - ASSERT(begin >= 0 && begin < Buckets()); - int i = begin; - for ( ; i < end && ! table[i].Empty() && BucketByPosition(i) <= begin; i++ ) - if ( BucketByPosition(i) == begin && table[i].Equal((char*)key, key_size, hash) ) - return i; + // Returns the position of the item if it exists. Otherwise returns -1, but set the insert + // position/distance if required. The starting point for the search may not be the bucket + // for the current table size since this method is also used to search for an item in the + // previous table size. + int LookupIndex(const void* key, int key_size, detail::hash_t hash, int begin, int end, + int* insert_position = nullptr, int* insert_distance = nullptr) { + ASSERT(begin >= 0 && begin < Buckets()); + int i = begin; + for ( ; i < end && ! table[i].Empty() && BucketByPosition(i) <= begin; i++ ) + if ( BucketByPosition(i) == begin && table[i].Equal((char*)key, key_size, hash) ) + return i; - // no such cluster, or not found in the cluster. - if ( insert_position ) - *insert_position = i; + // no such cluster, or not found in the cluster. + if ( insert_position ) + *insert_position = i; - if ( insert_distance ) - { - *insert_distance = i - begin; + if ( insert_distance ) { + *insert_distance = i - begin; - if ( *insert_distance >= detail::TOO_FAR_TO_REACH ) - reporter->FatalErrorWithCore("Dictionary (size %d) insertion distance too far: %d", - Length(), *insert_distance); - } + if ( *insert_distance >= detail::TOO_FAR_TO_REACH ) + reporter->FatalErrorWithCore("Dictionary (size %d) insertion distance too far: %d", Length(), + *insert_distance); + } - return -1; - } + return -1; + } - /// Insert entry, Adjust iterators when necessary. - void InsertRelocateAndAdjust(detail::DictEntry& entry, int insert_position) - { + /// Insert entry, Adjust iterators when necessary. + void InsertRelocateAndAdjust(detail::DictEntry& entry, int insert_position) { /// e.distance is adjusted to be the one at insert_position. #ifdef DEBUG - entry.bucket = BucketByHash(entry.hash, log2_buckets); + entry.bucket = BucketByHash(entry.hash, log2_buckets); #endif // DEBUG - int last_affected_position = insert_position; - InsertAndRelocate(entry, insert_position, &last_affected_position); - space_distance_sum += last_affected_position - insert_position; - space_distance_samples++; + int last_affected_position = insert_position; + InsertAndRelocate(entry, insert_position, &last_affected_position); + space_distance_sum += last_affected_position - insert_position; + space_distance_samples++; - // If remapping in progress, adjust the remap_end to step back a little to cover the new - // range if the changed range straddles over remap_end. - if ( Remapping() && insert_position <= remap_end && remap_end < last_affected_position ) - { //[i,j] range changed. if map_end in between. then possibly old entry pushed down - // across - // map_end. - remap_end = last_affected_position; // adjust to j on the conservative side. - } + // If remapping in progress, adjust the remap_end to step back a little to cover the new + // range if the changed range straddles over remap_end. + if ( Remapping() && insert_position <= remap_end && + remap_end < last_affected_position ) { //[i,j] range changed. if map_end in between. then possibly old + // entry pushed down + // across + // map_end. + remap_end = last_affected_position; // adjust to j on the conservative side. + } - if ( iterators && ! iterators->empty() ) - for ( auto c : *iterators ) - AdjustOnInsert(c, entry, insert_position, last_affected_position); - } + if ( iterators && ! iterators->empty() ) + for ( auto c : *iterators ) + AdjustOnInsert(c, entry, insert_position, last_affected_position); + } - /// insert entry into position, relocate other entries when necessary. - void InsertAndRelocate(detail::DictEntry& entry, int insert_position, - int* last_affected_position = nullptr) - { /// take out the head of cluster and append to the end of the cluster. - while ( true ) - { - if ( insert_position >= Capacity() ) - { - ASSERT(insert_position == Capacity()); - SizeUp(); // copied all the items to new table. as it's just copying without - // remapping, insert_position is now empty. - table[insert_position] = entry; - if ( last_affected_position ) - *last_affected_position = insert_position; - return; - } - if ( table[insert_position].Empty() ) - { // the condition to end the loop. - table[insert_position] = entry; - if ( last_affected_position ) - *last_affected_position = insert_position; - return; - } + /// insert entry into position, relocate other entries when necessary. + void InsertAndRelocate( + detail::DictEntry& entry, int insert_position, + int* last_affected_position = nullptr) { /// take out the head of cluster and append to the end of the cluster. + while ( true ) { + if ( insert_position >= Capacity() ) { + ASSERT(insert_position == Capacity()); + SizeUp(); // copied all the items to new table. as it's just copying without + // remapping, insert_position is now empty. + table[insert_position] = entry; + if ( last_affected_position ) + *last_affected_position = insert_position; + return; + } + if ( table[insert_position].Empty() ) { // the condition to end the loop. + table[insert_position] = entry; + if ( last_affected_position ) + *last_affected_position = insert_position; + return; + } - // the to-be-swapped-out item appends to the end of its original cluster. - auto t = table[insert_position]; - int next = EndOfClusterByPosition(insert_position); - t.distance += next - insert_position; + // the to-be-swapped-out item appends to the end of its original cluster. + auto t = table[insert_position]; + int next = EndOfClusterByPosition(insert_position); + t.distance += next - insert_position; - // swap - table[insert_position] = entry; - entry = t; - insert_position = next; // append to the end of the current cluster. - } - } + // swap + table[insert_position] = entry; + entry = t; + insert_position = next; // append to the end of the current cluster. + } + } - /// Adjust Iterators on Insert. - void AdjustOnInsert(RobustDictIterator* c, const detail::DictEntry& entry, - int insert_position, int last_affected_position) - { - // See note in Dictionary::AdjustOnInsert() above. - c->inserted->erase(std::remove(c->inserted->begin(), c->inserted->end(), entry), - c->inserted->end()); - c->visited->erase(std::remove(c->visited->begin(), c->visited->end(), entry), - c->visited->end()); + /// Adjust Iterators on Insert. + void AdjustOnInsert(RobustDictIterator* c, const detail::DictEntry& entry, int insert_position, + int last_affected_position) { + // See note in Dictionary::AdjustOnInsert() above. + c->inserted->erase(std::remove(c->inserted->begin(), c->inserted->end(), entry), c->inserted->end()); + c->visited->erase(std::remove(c->visited->begin(), c->visited->end(), entry), c->visited->end()); - if ( insert_position < c->next ) - c->inserted->push_back(entry); - if ( insert_position < c->next && c->next <= last_affected_position ) - { - int k = TailOfClusterByPosition(c->next); - ASSERT(k >= 0 && k < Capacity()); - c->visited->push_back(table[k]); - } - } + if ( insert_position < c->next ) + c->inserted->push_back(entry); + if ( insert_position < c->next && c->next <= last_affected_position ) { + int k = TailOfClusterByPosition(c->next); + ASSERT(k >= 0 && k < Capacity()); + c->visited->push_back(table[k]); + } + } - /// Remove, Relocate & Adjust iterators. - detail::DictEntry RemoveRelocateAndAdjust(int position) - { - int last_affected_position = position; - detail::DictEntry entry = RemoveAndRelocate(position, &last_affected_position); + /// Remove, Relocate & Adjust iterators. + detail::DictEntry RemoveRelocateAndAdjust(int position) { + int last_affected_position = position; + detail::DictEntry entry = RemoveAndRelocate(position, &last_affected_position); #ifdef ZEEK_DICT_DEBUG - // validation: index to i-1 should be continuous without empty spaces. - for ( int k = position; k < last_affected_position; k++ ) - ASSERT(! table[k].Empty()); + // validation: index to i-1 should be continuous without empty spaces. + for ( int k = position; k < last_affected_position; k++ ) + ASSERT(! table[k].Empty()); #endif // ZEEK_DICT_DEBUG - if ( iterators && ! iterators->empty() ) - for ( auto c : *iterators ) - AdjustOnRemove(c, entry, position, last_affected_position); + if ( iterators && ! iterators->empty() ) + for ( auto c : *iterators ) + AdjustOnRemove(c, entry, position, last_affected_position); - return entry; - } + return entry; + } - /// Remove & Relocate - detail::DictEntry RemoveAndRelocate(int position, int* last_affected_position = nullptr) - { - // fill the empty position with the tail of the cluster of position+1. - ASSERT(position >= 0 && position < Capacity() && ! table[position].Empty()); + /// Remove & Relocate + detail::DictEntry RemoveAndRelocate(int position, int* last_affected_position = nullptr) { + // fill the empty position with the tail of the cluster of position+1. + ASSERT(position >= 0 && position < Capacity() && ! table[position].Empty()); - detail::DictEntry entry = table[position]; - while ( true ) - { - if ( position == Capacity() - 1 || table[position + 1].Empty() || - table[position + 1].distance == 0 ) - { - // no next cluster to fill, or next position is empty or next position is already in - // perfect bucket. - table[position].SetEmpty(); - if ( last_affected_position ) - *last_affected_position = position; - return entry; - } - int next = TailOfClusterByPosition(position + 1); - table[position] = table[next]; - table[position].distance -= next - position; // distance improved for the item. - position = next; - } + detail::DictEntry entry = table[position]; + while ( true ) { + if ( position == Capacity() - 1 || table[position + 1].Empty() || table[position + 1].distance == 0 ) { + // no next cluster to fill, or next position is empty or next position is already in + // perfect bucket. + table[position].SetEmpty(); + if ( last_affected_position ) + *last_affected_position = position; + return entry; + } + int next = TailOfClusterByPosition(position + 1); + table[position] = table[next]; + table[position].distance -= next - position; // distance improved for the item. + position = next; + } - return entry; - } + return entry; + } - /// Adjust safe iterators after Removal of entry at position. - void AdjustOnRemove(RobustDictIterator* c, const detail::DictEntry& entry, int position, - int last_affected_position) - { - // See note in Dictionary::AdjustOnInsert() above. - c->inserted->erase(std::remove(c->inserted->begin(), c->inserted->end(), entry), - c->inserted->end()); - c->visited->erase(std::remove(c->visited->begin(), c->visited->end(), entry), - c->visited->end()); + /// Adjust safe iterators after Removal of entry at position. + void AdjustOnRemove(RobustDictIterator* c, const detail::DictEntry& entry, int position, + int last_affected_position) { + // See note in Dictionary::AdjustOnInsert() above. + c->inserted->erase(std::remove(c->inserted->begin(), c->inserted->end(), entry), c->inserted->end()); + c->visited->erase(std::remove(c->visited->begin(), c->visited->end(), entry), c->visited->end()); - if ( position < c->next && c->next <= last_affected_position ) - { - int moved = HeadOfClusterByPosition(c->next - 1); - if ( moved < position ) - moved = position; - c->inserted->push_back(table[moved]); - } + if ( position < c->next && c->next <= last_affected_position ) { + int moved = HeadOfClusterByPosition(c->next - 1); + if ( moved < position ) + moved = position; + c->inserted->push_back(table[moved]); + } - // if not already the end of the dictionary, adjust next to a valid one. - if ( c->next < Capacity() && table[c->next].Empty() ) - c->next = Next(c->next); + // if not already the end of the dictionary, adjust next to a valid one. + if ( c->next < Capacity() && table[c->next].Empty() ) + c->next = Next(c->next); - if ( c->curr == entry ) - { - if ( c->next >= 0 && c->next < Capacity() && ! table[c->next].Empty() ) - c->curr = table[c->next]; - else - c->curr = detail::DictEntry(nullptr); // -> c == end_robust() - } - } + if ( c->curr == entry ) { + if ( c->next >= 0 && c->next < Capacity() && ! table[c->next].Empty() ) + c->curr = table[c->next]; + else + c->curr = detail::DictEntry(nullptr); // -> c == end_robust() + } + } - bool Remapping() const { return remap_end >= 0; } // remap in reverse order. + bool Remapping() const { return remap_end >= 0; } // remap in reverse order. - /// One round of remap. - void Remap() - { - /// since remap should be very fast. take more at a time. - /// delay Remap when cookie is there. hard to handle cookie iteration while size changes. - /// remap from bottom up. - /// remap creates two parts of the dict: [0,remap_end] (remap_end, ...]. the former is mixed - /// with old/new entries; the latter contains all new entries. - /// - if ( num_iterators > 0 ) - return; + /// One round of remap. + void Remap() { + /// since remap should be very fast. take more at a time. + /// delay Remap when cookie is there. hard to handle cookie iteration while size changes. + /// remap from bottom up. + /// remap creates two parts of the dict: [0,remap_end] (remap_end, ...]. the former is mixed + /// with old/new entries; the latter contains all new entries. + /// + if ( num_iterators > 0 ) + return; - int left = detail::DICT_REMAP_ENTRIES; - while ( remap_end >= 0 && left > 0 ) - { - if ( ! table[remap_end].Empty() && Remap(remap_end) ) - left--; - else //< successful Remap may increase remap_end in the case of SizeUp due to insert. if - // so, - // remap_end need to be worked on again. - remap_end--; - } - if ( remap_end < 0 ) - remaps = 0; // done remapping. - } + int left = detail::DICT_REMAP_ENTRIES; + while ( remap_end >= 0 && left > 0 ) { + if ( ! table[remap_end].Empty() && Remap(remap_end) ) + left--; + else //< successful Remap may increase remap_end in the case of SizeUp due to insert. if + // so, + // remap_end need to be worked on again. + remap_end--; + } + if ( remap_end < 0 ) + remaps = 0; // done remapping. + } - // Remap an item in position to a new position. Returns true if the relocation was - // successful, false otherwise. new_position will be set to the new position if a - // pointer is provided to store the new value. - bool Remap(int position, int* new_position = nullptr) - { - ASSERT_VALID(this); - /// Remap changes item positions by remove() and insert(). to avoid excessive operation. - /// avoid it when safe iteration is in progress. - ASSERT(! iterators || iterators->empty()); - int current = BucketByPosition(position); // current bucket - int expected = BucketByHash(table[position].hash, log2_buckets); // expected bucket in new - // table. - // equal because 1: it's a new item, 2: it's an old item, but new bucket is the same as old. - // 50% of old items act this way due to fibhash. - if ( current == expected ) - return false; - detail::DictEntry entry = RemoveAndRelocate( - position); // no iteration cookies to adjust, no need for last_affected_position. + // Remap an item in position to a new position. Returns true if the relocation was + // successful, false otherwise. new_position will be set to the new position if a + // pointer is provided to store the new value. + bool Remap(int position, int* new_position = nullptr) { + ASSERT_VALID(this); + /// Remap changes item positions by remove() and insert(). to avoid excessive operation. + /// avoid it when safe iteration is in progress. + ASSERT(! iterators || iterators->empty()); + int current = BucketByPosition(position); // current bucket + int expected = BucketByHash(table[position].hash, log2_buckets); // expected bucket in new + // table. + // equal because 1: it's a new item, 2: it's an old item, but new bucket is the same as old. + // 50% of old items act this way due to fibhash. + if ( current == expected ) + return false; + detail::DictEntry entry = + RemoveAndRelocate(position); // no iteration cookies to adjust, no need for last_affected_position. #ifdef DEBUG - entry.bucket = expected; + entry.bucket = expected; #endif // DEBUG - // find insert position. - int insert_position = EndOfClusterByBucket(expected); - if ( new_position ) - *new_position = insert_position; - entry.distance = insert_position - expected; - InsertAndRelocate( - entry, - insert_position); // no iteration cookies to adjust, no need for last_affected_position. - ASSERT_VALID(this); - return true; - } + // find insert position. + int insert_position = EndOfClusterByBucket(expected); + if ( new_position ) + *new_position = insert_position; + entry.distance = insert_position - expected; + InsertAndRelocate(entry, + insert_position); // no iteration cookies to adjust, no need for last_affected_position. + ASSERT_VALID(this); + return true; + } - void SizeUp() - { - int prev_capacity = Capacity(); - SetLog2Buckets(log2_buckets + 1); + void SizeUp() { + int prev_capacity = Capacity(); + SetLog2Buckets(log2_buckets + 1); - int capacity = Capacity(); - table = (detail::DictEntry*)realloc(table, capacity * sizeof(detail::DictEntry)); - for ( int i = prev_capacity; i < capacity; i++ ) - table[i].SetEmpty(); + int capacity = Capacity(); + table = (detail::DictEntry*)realloc(table, capacity * sizeof(detail::DictEntry)); + for ( int i = prev_capacity; i < capacity; i++ ) + table[i].SetEmpty(); - // REmap from last to first in reverse order. SizeUp can be triggered by 2 conditions, one - // of which is that the last space in the table is occupied and there's nowhere to put new - // items. In this case, the table doubles in capacity and the item is put at the - // prev_capacity position with the old hash. We need to cover this item (?). - remap_end = prev_capacity; // prev_capacity instead of prev_capacity-1. + // REmap from last to first in reverse order. SizeUp can be triggered by 2 conditions, one + // of which is that the last space in the table is occupied and there's nowhere to put new + // items. In this case, the table doubles in capacity and the item is put at the + // prev_capacity position with the old hash. We need to cover this item (?). + remap_end = prev_capacity; // prev_capacity instead of prev_capacity-1. - // another remap starts. - remaps++; // used in Lookup() to cover SizeUp with incomplete remaps. - ASSERT(remaps <= log2_buckets); // because we only sizeUp, one direction. we know the - // previous log2_buckets. - // reset performance metrics. - space_distance_sum = 0; - space_distance_samples = 0; - } + // another remap starts. + remaps++; // used in Lookup() to cover SizeUp with incomplete remaps. + ASSERT(remaps <= log2_buckets); // because we only sizeUp, one direction. we know the + // previous log2_buckets. + // reset performance metrics. + space_distance_sum = 0; + space_distance_samples = 0; + } - /** - * Retrieves a pointer to a full DictEntry in the table based on a hash key. - * - * @param key the key to lookup. - * @return A pointer to the entry or a nullptr if no entry has a matching key. - */ - detail::DictEntry* LookupEntry(const detail::HashKey& key) - { - return LookupEntry(key.Key(), key.Size(), key.Hash()); - } + /** + * Retrieves a pointer to a full DictEntry in the table based on a hash key. + * + * @param key the key to lookup. + * @return A pointer to the entry or a nullptr if no entry has a matching key. + */ + detail::DictEntry* LookupEntry(const detail::HashKey& key) { + return LookupEntry(key.Key(), key.Size(), key.Hash()); + } - /** - * Retrieves a pointer to a full DictEntry in the table based on key data. - * - * @param key the key to lookup - * @param key_size the size of the key data - * @param h a hash of the key data. - * @return A pointer to the entry or a nullptr if no entry has a matching key. - */ - detail::DictEntry* LookupEntry(const void* key, int key_size, detail::hash_t h) const - { - // Look up possibly modifies the entry. Why? if the entry is found but not positioned - // according to the current dict (so it's before SizeUp), it will be moved to the right - // position so next lookup is fast. - Dictionary* d = const_cast(this); - int position = d->LookupIndex(key, key_size, h); - return position >= 0 ? &(table[position]) : nullptr; - } + /** + * Retrieves a pointer to a full DictEntry in the table based on key data. + * + * @param key the key to lookup + * @param key_size the size of the key data + * @param h a hash of the key data. + * @return A pointer to the entry or a nullptr if no entry has a matching key. + */ + detail::DictEntry* LookupEntry(const void* key, int key_size, detail::hash_t h) const { + // Look up possibly modifies the entry. Why? if the entry is found but not positioned + // according to the current dict (so it's before SizeUp), it will be moved to the right + // position so next lookup is fast. + Dictionary* d = const_cast(this); + int position = d->LookupIndex(key, key_size, h); + return position >= 0 ? &(table[position]) : nullptr; + } - bool HaveOnlyRobustIterators() const - { - return (num_iterators == 0) || ((iterators ? iterators->size() : 0) == num_iterators); - } + bool HaveOnlyRobustIterators() const { + return (num_iterators == 0) || ((iterators ? iterators->size() : 0) == num_iterators); + } - RobustDictIterator MakeRobustIterator() - { - if ( IsOrdered() ) - reporter->InternalError( - "RobustIterators are not currently supported for ordered dictionaries"); + RobustDictIterator MakeRobustIterator() { + if ( IsOrdered() ) + reporter->InternalError("RobustIterators are not currently supported for ordered dictionaries"); - if ( ! iterators ) - iterators = new std::vector*>; + if ( ! iterators ) + iterators = new std::vector*>; - return {this}; - } + return {this}; + } - detail::DictEntry GetNextRobustIteration(RobustDictIterator* iter) - { - // If there's no table in the dictionary, then the iterator needs to be - // cleaned up because it's not pointing at anything. - if ( ! table ) - { - iter->Complete(); - return detail::DictEntry(nullptr); // end of iteration - } + detail::DictEntry GetNextRobustIteration(RobustDictIterator* iter) { + // If there's no table in the dictionary, then the iterator needs to be + // cleaned up because it's not pointing at anything. + if ( ! table ) { + iter->Complete(); + return detail::DictEntry(nullptr); // end of iteration + } - // If there are any inserted entries, return them first. - // That keeps the list small and helps avoiding searching - // a large list when deleting an entry. - if ( iter->inserted && ! iter->inserted->empty() ) - { - // Return the last one. Order doesn't matter, - // and removing from the tail is cheaper. - detail::DictEntry e = iter->inserted->back(); - iter->inserted->pop_back(); - return e; - } + // If there are any inserted entries, return them first. + // That keeps the list small and helps avoiding searching + // a large list when deleting an entry. + if ( iter->inserted && ! iter->inserted->empty() ) { + // Return the last one. Order doesn't matter, + // and removing from the tail is cheaper. + detail::DictEntry e = iter->inserted->back(); + iter->inserted->pop_back(); + return e; + } - // First iteration. - if ( iter->next < 0 ) - iter->next = Next(-1); + // First iteration. + if ( iter->next < 0 ) + iter->next = Next(-1); - if ( iter->next < Capacity() && table[iter->next].Empty() ) - { - // [Robin] I believe this means that the table has resized in a way - // that we're now inside the overflow area where elements are empty, - // because elsewhere empty slots aren't allowed. Assuming that's right, - // then it means we'll always be at the end of the table now and could - // also just set `next` to capacity. However, just to be sure, we - // instead reuse logic from below to move forward "to a valid position" - // and then double check, through an assertion in debug mode, that it's - // actually the end. If this ever triggered, the above assumption would - // be wrong (but the Next() call would probably still be right). - iter->next = Next(iter->next); - ASSERT(iter->next == Capacity()); - } + if ( iter->next < Capacity() && table[iter->next].Empty() ) { + // [Robin] I believe this means that the table has resized in a way + // that we're now inside the overflow area where elements are empty, + // because elsewhere empty slots aren't allowed. Assuming that's right, + // then it means we'll always be at the end of the table now and could + // also just set `next` to capacity. However, just to be sure, we + // instead reuse logic from below to move forward "to a valid position" + // and then double check, through an assertion in debug mode, that it's + // actually the end. If this ever triggered, the above assumption would + // be wrong (but the Next() call would probably still be right). + iter->next = Next(iter->next); + ASSERT(iter->next == Capacity()); + } - // Filter out visited keys. - int capacity = Capacity(); - if ( iter->visited && ! iter->visited->empty() ) - // Filter out visited entries. - while ( iter->next < capacity ) - { - ASSERT(! table[iter->next].Empty()); - auto it = std::find(iter->visited->begin(), iter->visited->end(), - table[iter->next]); - if ( it == iter->visited->end() ) - break; - iter->visited->erase(it); - iter->next = Next(iter->next); - } + // Filter out visited keys. + int capacity = Capacity(); + if ( iter->visited && ! iter->visited->empty() ) + // Filter out visited entries. + while ( iter->next < capacity ) { + ASSERT(! table[iter->next].Empty()); + auto it = std::find(iter->visited->begin(), iter->visited->end(), table[iter->next]); + if ( it == iter->visited->end() ) + break; + iter->visited->erase(it); + iter->next = Next(iter->next); + } - if ( iter->next >= capacity ) - { - iter->Complete(); - return detail::DictEntry(nullptr); // end of iteration - } + if ( iter->next >= capacity ) { + iter->Complete(); + return detail::DictEntry(nullptr); // end of iteration + } - ASSERT(! table[iter->next].Empty()); - detail::DictEntry e = table[iter->next]; + ASSERT(! table[iter->next].Empty()); + detail::DictEntry e = table[iter->next]; - // prepare for next time. - iter->next = Next(iter->next); - return e; - } + // prepare for next time. + iter->next = Next(iter->next); + return e; + } - void IncrIters() { ++num_iterators; } - void DecrIters() { --num_iterators; } + void IncrIters() { ++num_iterators; } + void DecrIters() { --num_iterators; } - // aligned on 8-bytes with 4-leading bytes. 7*8=56 bytes a dictionary. + // aligned on 8-bytes with 4-leading bytes. 7*8=56 bytes a dictionary. - // when sizeup but the current mapping is in progress. the current mapping will be ignored - // as it will be remapped to new dict size anyway. however, the missed count is recorded - // for lookup. if position not found for a key in the position of dict of current size, it - // still could be in the position of dict of previous N sizes. - uint16_t remaps = 0; - uint16_t log2_buckets = 0; - uint32_t bucket_capacity = 1; - uint32_t bucket_count = 1; + // when sizeup but the current mapping is in progress. the current mapping will be ignored + // as it will be remapped to new dict size anyway. however, the missed count is recorded + // for lookup. if position not found for a key in the position of dict of current size, it + // still could be in the position of dict of previous N sizes. + uint16_t remaps = 0; + uint16_t log2_buckets = 0; + uint32_t bucket_capacity = 1; + uint32_t bucket_count = 1; - // Pending number of iterators on the Dict, including both robust and non-robust. - // This is used to avoid remapping if there are any active iterators. - uint16_t num_iterators = 0; + // Pending number of iterators on the Dict, including both robust and non-robust. + // This is used to avoid remapping if there are any active iterators. + uint16_t num_iterators = 0; - // The last index to be remapped. - int32_t remap_end = -1; + // The last index to be remapped. + int32_t remap_end = -1; - uint32_t num_entries = 0; - uint32_t max_entries = 0; - uint64_t cum_entries = 0; - uint32_t space_distance_samples = 0; - // how far the space is - int64_t space_distance_sum = 0; + uint32_t num_entries = 0; + uint32_t max_entries = 0; + uint64_t cum_entries = 0; + uint32_t space_distance_samples = 0; + // how far the space is + int64_t space_distance_sum = 0; - dict_delete_func delete_func = nullptr; - detail::DictEntry* table = nullptr; - std::vector*>* iterators = nullptr; + dict_delete_func delete_func = nullptr; + detail::DictEntry* table = nullptr; + std::vector*>* iterators = nullptr; - // Ordered dictionaries keep the order based on some criteria, by default the order of - // insertion. We only store a copy of the keys here for memory savings and for safety - // around reallocs and such. - std::unique_ptr order; - }; + // Ordered dictionaries keep the order based on some criteria, by default the order of + // insertion. We only store a copy of the keys here for memory savings and for safety + // around reallocs and such. + std::unique_ptr order; +}; -template using PDict = Dictionary; +template +using PDict = Dictionary; - } // namespace zeek +} // namespace zeek diff --git a/src/Discard.cc b/src/Discard.cc index d55c718b52..a1cfbb9ece 100644 --- a/src/Discard.cc +++ b/src/Discard.cc @@ -14,153 +14,130 @@ #include "zeek/Var.h" #include "zeek/ZeekString.h" -namespace zeek::detail - { +namespace zeek::detail { -Discarder::Discarder() - { - check_ip = id::find_func("discarder_check_ip"); - check_tcp = id::find_func("discarder_check_tcp"); - check_udp = id::find_func("discarder_check_udp"); - check_icmp = id::find_func("discarder_check_icmp"); +Discarder::Discarder() { + check_ip = id::find_func("discarder_check_ip"); + check_tcp = id::find_func("discarder_check_tcp"); + check_udp = id::find_func("discarder_check_udp"); + check_icmp = id::find_func("discarder_check_icmp"); - discarder_maxlen = static_cast(id::find_val("discarder_maxlen")->AsCount()); - } + discarder_maxlen = static_cast(id::find_val("discarder_maxlen")->AsCount()); +} -bool Discarder::IsActive() - { - return check_ip || check_tcp || check_udp || check_icmp; - } +bool Discarder::IsActive() { return check_ip || check_tcp || check_udp || check_icmp; } -bool Discarder::NextPacket(const std::shared_ptr& ip, int len, int caplen) - { - bool discard_packet = false; +bool Discarder::NextPacket(const std::shared_ptr& ip, int len, int caplen) { + bool discard_packet = false; - if ( check_ip ) - { - zeek::Args args{ip->ToPktHdrVal()}; + if ( check_ip ) { + zeek::Args args{ip->ToPktHdrVal()}; - try - { - discard_packet = check_ip->Invoke(&args)->AsBool(); - } + try { + discard_packet = check_ip->Invoke(&args)->AsBool(); + } - catch ( InterpreterException& e ) - { - discard_packet = false; - } + catch ( InterpreterException& e ) { + discard_packet = false; + } - if ( discard_packet ) - return discard_packet; - } + if ( discard_packet ) + return discard_packet; + } - int proto = ip->NextProto(); - if ( proto != IPPROTO_TCP && proto != IPPROTO_UDP && proto != IPPROTO_ICMP ) - // This is not a protocol we understand. - return false; + int proto = ip->NextProto(); + if ( proto != IPPROTO_TCP && proto != IPPROTO_UDP && proto != IPPROTO_ICMP ) + // This is not a protocol we understand. + return false; - // XXX shall we only check the first packet??? - if ( ip->IsFragment() ) - // Never check any fragment. - return false; + // XXX shall we only check the first packet??? + if ( ip->IsFragment() ) + // Never check any fragment. + return false; - int ip_hdr_len = ip->HdrLen(); - len -= ip_hdr_len; // remove IP header - caplen -= ip_hdr_len; + int ip_hdr_len = ip->HdrLen(); + len -= ip_hdr_len; // remove IP header + caplen -= ip_hdr_len; - bool is_tcp = (proto == IPPROTO_TCP); - bool is_udp = (proto == IPPROTO_UDP); - int min_hdr_len = is_tcp ? sizeof(struct tcphdr) - : (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp)); + bool is_tcp = (proto == IPPROTO_TCP); + bool is_udp = (proto == IPPROTO_UDP); + int min_hdr_len = is_tcp ? sizeof(struct tcphdr) : (is_udp ? sizeof(struct udphdr) : sizeof(struct icmp)); - if ( len < min_hdr_len || caplen < min_hdr_len ) - // we don't have a complete protocol header - return false; + if ( len < min_hdr_len || caplen < min_hdr_len ) + // we don't have a complete protocol header + return false; - // Where the data starts - if this is a protocol we know about, - // this gets advanced past the transport header. - const u_char* data = ip->Payload(); + // Where the data starts - if this is a protocol we know about, + // this gets advanced past the transport header. + const u_char* data = ip->Payload(); - if ( is_tcp ) - { - if ( check_tcp ) - { - const struct tcphdr* tp = (const struct tcphdr*)data; - int th_len = tp->th_off * 4; + if ( is_tcp ) { + if ( check_tcp ) { + const struct tcphdr* tp = (const struct tcphdr*)data; + int th_len = tp->th_off * 4; - zeek::Args args{ - ip->ToPktHdrVal(), - {AdoptRef{}, BuildData(data, th_len, len, caplen)}, - }; + zeek::Args args{ + ip->ToPktHdrVal(), + {AdoptRef{}, BuildData(data, th_len, len, caplen)}, + }; - try - { - discard_packet = check_tcp->Invoke(&args)->AsBool(); - } + try { + discard_packet = check_tcp->Invoke(&args)->AsBool(); + } - catch ( InterpreterException& e ) - { - discard_packet = false; - } - } - } + catch ( InterpreterException& e ) { + discard_packet = false; + } + } + } - else if ( is_udp ) - { - if ( check_udp ) - { - const struct udphdr* up = (const struct udphdr*)data; - int uh_len = sizeof(struct udphdr); + else if ( is_udp ) { + if ( check_udp ) { + const struct udphdr* up = (const struct udphdr*)data; + int uh_len = sizeof(struct udphdr); - zeek::Args args{ - ip->ToPktHdrVal(), - {AdoptRef{}, BuildData(data, uh_len, len, caplen)}, - }; + zeek::Args args{ + ip->ToPktHdrVal(), + {AdoptRef{}, BuildData(data, uh_len, len, caplen)}, + }; - try - { - discard_packet = check_udp->Invoke(&args)->AsBool(); - } + try { + discard_packet = check_udp->Invoke(&args)->AsBool(); + } - catch ( InterpreterException& e ) - { - discard_packet = false; - } - } - } + catch ( InterpreterException& e ) { + discard_packet = false; + } + } + } - else - { - if ( check_icmp ) - { - const struct icmp* ih = (const struct icmp*)data; + else { + if ( check_icmp ) { + const struct icmp* ih = (const struct icmp*)data; - zeek::Args args{ip->ToPktHdrVal()}; + zeek::Args args{ip->ToPktHdrVal()}; - try - { - discard_packet = check_icmp->Invoke(&args)->AsBool(); - } + try { + discard_packet = check_icmp->Invoke(&args)->AsBool(); + } - catch ( InterpreterException& e ) - { - discard_packet = false; - } - } - } + catch ( InterpreterException& e ) { + discard_packet = false; + } + } + } - return discard_packet; - } + return discard_packet; +} -Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) - { - len -= hdrlen; - caplen -= hdrlen; - data += hdrlen; +Val* Discarder::BuildData(const u_char* data, int hdrlen, int len, int caplen) { + len -= hdrlen; + caplen -= hdrlen; + data += hdrlen; - len = std::max(std::min(std::min(len, caplen), discarder_maxlen), 0); + len = std::max(std::min(std::min(len, caplen), discarder_maxlen), 0); - return new StringVal(new String(data, len, true)); - } + return new StringVal(new String(data, len, true)); +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Discard.h b/src/Discard.h index 2aa3dc4bf8..9049516329 100644 --- a/src/Discard.h +++ b/src/Discard.h @@ -7,38 +7,35 @@ #include "zeek/IntrusivePtr.h" -namespace zeek - { +namespace zeek { class IP_Hdr; class Val; class Func; using FuncPtr = IntrusivePtr; -namespace detail - { +namespace detail { -class Discarder final - { +class Discarder final { public: - Discarder(); - ~Discarder() = default; + Discarder(); + ~Discarder() = default; - bool IsActive(); + bool IsActive(); - bool NextPacket(const std::shared_ptr& ip, int len, int caplen); + bool NextPacket(const std::shared_ptr& ip, int len, int caplen); protected: - Val* BuildData(const u_char* data, int hdrlen, int len, int caplen); + Val* BuildData(const u_char* data, int hdrlen, int len, int caplen); - FuncPtr check_ip; - FuncPtr check_tcp; - FuncPtr check_udp; - FuncPtr check_icmp; + FuncPtr check_ip; + FuncPtr check_tcp; + FuncPtr check_udp; + FuncPtr check_icmp; - // Maximum amount of application data passed to filtering functions. - int discarder_maxlen; - }; + // Maximum amount of application data passed to filtering functions. + int discarder_maxlen; +}; - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/EquivClass.cc b/src/EquivClass.cc index 7fba212541..2e9482b621 100644 --- a/src/EquivClass.cc +++ b/src/EquivClass.cc @@ -7,190 +7,168 @@ #include "zeek/CCL.h" #include "zeek/util.h" -namespace zeek::detail - { +namespace zeek::detail { -EquivClass::EquivClass(int arg_size) - { - size = arg_size; - fwd = new int[size]; - bck = new int[size]; - equiv_class = new int[size]; - rep = new int[size]; - ccl_flags = nullptr; - num_ecs = 0; +EquivClass::EquivClass(int arg_size) { + size = arg_size; + fwd = new int[size]; + bck = new int[size]; + equiv_class = new int[size]; + rep = new int[size]; + ccl_flags = nullptr; + num_ecs = 0; - ec_nil = no_class = no_rep = size + 1; + ec_nil = no_class = no_rep = size + 1; - bck[0] = ec_nil; - fwd[size - 1] = ec_nil; + bck[0] = ec_nil; + fwd[size - 1] = ec_nil; - for ( int i = 0; i < size; ++i ) - { - if ( i > 0 ) - { - fwd[i - 1] = i; - bck[i] = i - 1; - } + for ( int i = 0; i < size; ++i ) { + if ( i > 0 ) { + fwd[i - 1] = i; + bck[i] = i - 1; + } - equiv_class[i] = no_class; - rep[i] = no_rep; - } - } + equiv_class[i] = no_class; + rep[i] = no_rep; + } +} -EquivClass::~EquivClass() - { - delete[] fwd; - delete[] bck; - delete[] equiv_class; - delete[] rep; - delete[] ccl_flags; - } +EquivClass::~EquivClass() { + delete[] fwd; + delete[] bck; + delete[] equiv_class; + delete[] rep; + delete[] ccl_flags; +} -void EquivClass::ConvertCCL(CCL* ccl) - { - // For each character in the class, add the character's - // equivalence class to the new "character" class we are - // creating. Thus when we are all done, the character class - // will really consist of collections of equivalence classes - // instead of collections of characters. +void EquivClass::ConvertCCL(CCL* ccl) { + // For each character in the class, add the character's + // equivalence class to the new "character" class we are + // creating. Thus when we are all done, the character class + // will really consist of collections of equivalence classes + // instead of collections of characters. - int_list* c_syms = ccl->Syms(); - int_list* new_syms = new int_list; + int_list* c_syms = ccl->Syms(); + int_list* new_syms = new int_list; - for ( auto sym : *c_syms ) - { - if ( IsRep(sym) ) - new_syms->push_back(SymEquivClass(sym)); - } + for ( auto sym : *c_syms ) { + if ( IsRep(sym) ) + new_syms->push_back(SymEquivClass(sym)); + } - ccl->ReplaceSyms(new_syms); - } + ccl->ReplaceSyms(new_syms); +} -int EquivClass::BuildECs() - { - // Create equivalence class numbers. If bck[x] is nil, - // then x is the representative of its equivalence class. +int EquivClass::BuildECs() { + // Create equivalence class numbers. If bck[x] is nil, + // then x is the representative of its equivalence class. - for ( int i = 0; i < size; ++i ) - if ( bck[i] == ec_nil ) - { - equiv_class[i] = num_ecs++; - rep[i] = i; - for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) - { - equiv_class[j] = equiv_class[i]; - rep[j] = i; - } - } + for ( int i = 0; i < size; ++i ) + if ( bck[i] == ec_nil ) { + equiv_class[i] = num_ecs++; + rep[i] = i; + for ( int j = fwd[i]; j != ec_nil; j = fwd[j] ) { + equiv_class[j] = equiv_class[i]; + rep[j] = i; + } + } - return num_ecs; - } + return num_ecs; +} -void EquivClass::CCL_Use(CCL* ccl) - { - // Note that it doesn't matter whether or not the character class is - // negated. The same results will be obtained in either case. +void EquivClass::CCL_Use(CCL* ccl) { + // Note that it doesn't matter whether or not the character class is + // negated. The same results will be obtained in either case. - if ( ! ccl_flags ) - { - ccl_flags = new int[size]; - for ( int i = 0; i < size; ++i ) - ccl_flags[i] = 0; - } + if ( ! ccl_flags ) { + ccl_flags = new int[size]; + for ( int i = 0; i < size; ++i ) + ccl_flags[i] = 0; + } - int_list* csyms = ccl->Syms(); - for ( size_t i = 0; i < csyms->size(); /* no increment */ ) - { - int sym = (*csyms)[i]; + int_list* csyms = ccl->Syms(); + for ( size_t i = 0; i < csyms->size(); /* no increment */ ) { + int sym = (*csyms)[i]; - int old_ec = bck[sym]; - int new_ec = sym; + int old_ec = bck[sym]; + int new_ec = sym; - size_t j = i + 1; + size_t j = i + 1; - for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) - { // look for the symbol in the character class - for ( ; j < csyms->size(); ++j ) - { - if ( (*csyms)[j] > k ) - // Since the character class is sorted, - // we can stop. - break; + for ( int k = fwd[sym]; k && k < size; k = fwd[k] ) { // look for the symbol in the character class + for ( ; j < csyms->size(); ++j ) { + if ( (*csyms)[j] > k ) + // Since the character class is sorted, + // we can stop. + break; - if ( (*csyms)[j] == k && ! ccl_flags[j] ) - { - // We found an old companion of sym - // in the ccl. Link it into the new - // equivalence class and flag it as - // having been processed. - bck[k] = new_ec; - fwd[new_ec] = k; - new_ec = k; + if ( (*csyms)[j] == k && ! ccl_flags[j] ) { + // We found an old companion of sym + // in the ccl. Link it into the new + // equivalence class and flag it as + // having been processed. + bck[k] = new_ec; + fwd[new_ec] = k; + new_ec = k; - // Set flag so we don't reprocess. - ccl_flags[j] = 1; + // Set flag so we don't reprocess. + ccl_flags[j] = 1; - // Get next equivalence class member. - break; - } - } + // Get next equivalence class member. + break; + } + } - if ( j < csyms->size() && (*csyms)[j] == k ) - // We broke out of the above loop by finding - // an old companion - go to the next symbol. - continue; + if ( j < csyms->size() && (*csyms)[j] == k ) + // We broke out of the above loop by finding + // an old companion - go to the next symbol. + continue; - // Symbol isn't in character class. Put it in the old - // equivalence class. - bck[k] = old_ec; - if ( old_ec != ec_nil ) - fwd[old_ec] = k; + // Symbol isn't in character class. Put it in the old + // equivalence class. + bck[k] = old_ec; + if ( old_ec != ec_nil ) + fwd[old_ec] = k; - old_ec = k; - } + old_ec = k; + } - if ( bck[sym] != ec_nil || old_ec != bck[sym] ) - { - bck[sym] = ec_nil; - fwd[old_ec] = ec_nil; - } + if ( bck[sym] != ec_nil || old_ec != bck[sym] ) { + bck[sym] = ec_nil; + fwd[old_ec] = ec_nil; + } - fwd[new_ec] = ec_nil; + fwd[new_ec] = ec_nil; - // Find next ccl member to process. - for ( ++i; i < csyms->size() && ccl_flags[i]; ++i ) - // Reset "doesn't need processing" flag. - ccl_flags[i] = 0; - } - } + // Find next ccl member to process. + for ( ++i; i < csyms->size() && ccl_flags[i]; ++i ) + // Reset "doesn't need processing" flag. + ccl_flags[i] = 0; + } +} -void EquivClass::UniqueChar(int sym) - { - // If until now the character has been a proper subset of - // an equivalence class, break it away to create a new ec. +void EquivClass::UniqueChar(int sym) { + // If until now the character has been a proper subset of + // an equivalence class, break it away to create a new ec. - if ( fwd[sym] != ec_nil ) - bck[fwd[sym]] = bck[sym]; + if ( fwd[sym] != ec_nil ) + bck[fwd[sym]] = bck[sym]; - if ( bck[sym] != ec_nil ) - fwd[bck[sym]] = fwd[sym]; + if ( bck[sym] != ec_nil ) + fwd[bck[sym]] = fwd[sym]; - fwd[sym] = ec_nil; - bck[sym] = ec_nil; - } + fwd[sym] = ec_nil; + bck[sym] = ec_nil; +} -void EquivClass::Dump(FILE* f) - { - fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs); - for ( int i = 0; i < size; ++i ) - if ( SymEquivClass(i) != 0 ) // skip usually huge default ec - fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i)); - } +void EquivClass::Dump(FILE* f) { + fprintf(f, "%d symbols in EC yielded %d ecs\n", size, num_ecs); + for ( int i = 0; i < size; ++i ) + if ( SymEquivClass(i) != 0 ) // skip usually huge default ec + fprintf(f, "map %d ('%c') -> %d\n", i, i, SymEquivClass(i)); +} -int EquivClass::Size() const - { - return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4)); - } +int EquivClass::Size() const { return padded_sizeof(*this) + util::pad_size(sizeof(int) * size * (ccl_flags ? 5 : 4)); } - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/EquivClass.h b/src/EquivClass.h index 272f6b83c7..c2a3e5fad6 100644 --- a/src/EquivClass.h +++ b/src/EquivClass.h @@ -4,46 +4,44 @@ #include -namespace zeek::detail - { +namespace zeek::detail { class CCL; -class EquivClass - { +class EquivClass { public: - explicit EquivClass(int size); - ~EquivClass(); + explicit EquivClass(int size); + ~EquivClass(); - void UniqueChar(int sym); - void CCL_Use(CCL* ccl); + void UniqueChar(int sym); + void CCL_Use(CCL* ccl); - // All done adding character usage info - generate equivalence - // classes. Returns number of classes. - int BuildECs(); + // All done adding character usage info - generate equivalence + // classes. Returns number of classes. + int BuildECs(); - void ConvertCCL(CCL* ccl); + void ConvertCCL(CCL* ccl); - bool IsRep(int sym) const { return rep[sym] == sym; } - int EquivRep(int sym) const { return rep[sym]; } - int SymEquivClass(int sym) const { return equiv_class[sym]; } - int* EquivClasses() const { return equiv_class; } + bool IsRep(int sym) const { return rep[sym] == sym; } + int EquivRep(int sym) const { return rep[sym]; } + int SymEquivClass(int sym) const { return equiv_class[sym]; } + int* EquivClasses() const { return equiv_class; } - int NumSyms() const { return size; } - int NumClasses() const { return num_ecs; } + int NumSyms() const { return size; } + int NumClasses() const { return num_ecs; } - void Dump(FILE* f); - int Size() const; + void Dump(FILE* f); + int Size() const; protected: - int size; // size of character set - int num_ecs; // size of equivalence classes - int* fwd; // forward list of different classes - int* bck; // backward list - int* equiv_class; // symbol's equivalence class - int* rep; // representative for symbol's equivalence class - int* ccl_flags; - int ec_nil, no_class, no_rep; - }; + int size; // size of character set + int num_ecs; // size of equivalence classes + int* fwd; // forward list of different classes + int* bck; // backward list + int* equiv_class; // symbol's equivalence class + int* rep; // representative for symbol's equivalence class + int* ccl_flags; + int ec_nil, no_class, no_rep; +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Event.cc b/src/Event.cc index f382163108..ead50a1ad8 100644 --- a/src/Event.cc +++ b/src/Event.cc @@ -15,204 +15,187 @@ zeek::EventMgr zeek::event_mgr; -namespace zeek - { +namespace zeek { -Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, - util::detail::SourceID arg_src, analyzer::ID arg_aid, Obj* arg_obj, double arg_ts) - : handler(arg_handler), args(std::move(arg_args)), src(arg_src), aid(arg_aid), ts(arg_ts), - obj(arg_obj), next_event(nullptr) - { - if ( obj ) - Ref(obj); - } +Event::Event(const EventHandlerPtr& arg_handler, zeek::Args arg_args, util::detail::SourceID arg_src, + analyzer::ID arg_aid, Obj* arg_obj, double arg_ts) + : handler(arg_handler), + args(std::move(arg_args)), + src(arg_src), + aid(arg_aid), + ts(arg_ts), + obj(arg_obj), + next_event(nullptr) { + if ( obj ) + Ref(obj); +} -void Event::Describe(ODesc* d) const - { - if ( d->IsReadable() ) - d->AddSP("event"); +void Event::Describe(ODesc* d) const { + if ( d->IsReadable() ) + d->AddSP("event"); - bool s = d->IsShort(); - d->SetShort(s); + bool s = d->IsShort(); + d->SetShort(s); - if ( ! d->IsBinary() ) - d->Add("("); - describe_vals(args, d); - if ( ! d->IsBinary() ) - d->Add("("); - } + if ( ! d->IsBinary() ) + d->Add("("); + describe_vals(args, d); + if ( ! d->IsBinary() ) + d->Add("("); +} -void Event::Dispatch(bool no_remote) - { - if ( src == util::detail::SOURCE_BROKER ) - no_remote = true; +void Event::Dispatch(bool no_remote) { + if ( src == util::detail::SOURCE_BROKER ) + no_remote = true; - if ( handler->ErrorHandler() ) - reporter->BeginErrorHandler(); + if ( handler->ErrorHandler() ) + reporter->BeginErrorHandler(); - try - { - handler->Call(&args, no_remote, ts); - } + try { + handler->Call(&args, no_remote, ts); + } - catch ( InterpreterException& e ) - { - // Already reported. - } + catch ( InterpreterException& e ) { + // Already reported. + } - if ( obj ) - // obj->EventDone(); - Unref(obj); + if ( obj ) + // obj->EventDone(); + Unref(obj); - if ( handler->ErrorHandler() ) - reporter->EndErrorHandler(); - } + if ( handler->ErrorHandler() ) + reporter->EndErrorHandler(); +} -EventMgr::EventMgr() - { - head = tail = nullptr; - current_src = util::detail::SOURCE_LOCAL; - current_aid = 0; - current_ts = 0; - src_val = nullptr; - draining = false; - } +EventMgr::EventMgr() { + head = tail = nullptr; + current_src = util::detail::SOURCE_LOCAL; + current_aid = 0; + current_ts = 0; + src_val = nullptr; + draining = false; +} -EventMgr::~EventMgr() - { - while ( head ) - { - Event* n = head->NextEvent(); - Unref(head); - head = n; - } +EventMgr::~EventMgr() { + while ( head ) { + Event* n = head->NextEvent(); + Unref(head); + head = n; + } - Unref(src_val); - } + Unref(src_val); +} -void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, - analyzer::ID aid, Obj* obj, double ts) - { - QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts)); - } +void EventMgr::Enqueue(const EventHandlerPtr& h, Args vl, util::detail::SourceID src, analyzer::ID aid, Obj* obj, + double ts) { + QueueEvent(new Event(h, std::move(vl), src, aid, obj, ts)); +} -void EventMgr::QueueEvent(Event* event) - { - bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false); +void EventMgr::QueueEvent(Event* event) { + bool done = PLUGIN_HOOK_WITH_RESULT(HOOK_QUEUE_EVENT, HookQueueEvent(event), false); - if ( done ) - return; + if ( done ) + return; - if ( ! head ) - { - head = tail = event; - queue_flare.Fire(); - } - else - { - tail->SetNext(event); - tail = event; - } + if ( ! head ) { + head = tail = event; + queue_flare.Fire(); + } + else { + tail->SetNext(event); + tail = event; + } - ++event_mgr.num_events_queued; - } + ++event_mgr.num_events_queued; +} -void EventMgr::Dispatch(Event* event, bool no_remote) - { - current_src = event->Source(); - current_aid = event->Analyzer(); - current_ts = event->Time(); - event->Dispatch(no_remote); - Unref(event); - } +void EventMgr::Dispatch(Event* event, bool no_remote) { + current_src = event->Source(); + current_aid = event->Analyzer(); + current_ts = event->Time(); + event->Dispatch(no_remote); + Unref(event); +} -void EventMgr::Drain() - { - if ( event_queue_flush_point ) - Enqueue(event_queue_flush_point, Args{}); +void EventMgr::Drain() { + if ( event_queue_flush_point ) + Enqueue(event_queue_flush_point, Args{}); - detail::SegmentProfiler prof(detail::segment_logger, "draining-events"); + detail::SegmentProfiler prof(detail::segment_logger, "draining-events"); - PLUGIN_HOOK_VOID(HOOK_DRAIN_EVENTS, HookDrainEvents()); + PLUGIN_HOOK_VOID(HOOK_DRAIN_EVENTS, HookDrainEvents()); - draining = true; + draining = true; - // Past Zeek versions drained as long as there events, including when - // a handler queued new events during its execution. This could lead - // to endless loops in case a handler kept triggering its own event. - // We now limit this to just a couple of rounds. We do more than - // just one round to make it less likely to break existing scripts - // that expect the old behavior to trigger something quickly. + // Past Zeek versions drained as long as there events, including when + // a handler queued new events during its execution. This could lead + // to endless loops in case a handler kept triggering its own event. + // We now limit this to just a couple of rounds. We do more than + // just one round to make it less likely to break existing scripts + // that expect the old behavior to trigger something quickly. - for ( int round = 0; head && round < 2; round++ ) - { - Event* current = head; - head = nullptr; - tail = nullptr; + for ( int round = 0; head && round < 2; round++ ) { + Event* current = head; + head = nullptr; + tail = nullptr; - while ( current ) - { - Event* next = current->NextEvent(); + while ( current ) { + Event* next = current->NextEvent(); - current_src = current->Source(); - current_aid = current->Analyzer(); - current_ts = current->Time(); - current->Dispatch(); - Unref(current); + current_src = current->Source(); + current_aid = current->Analyzer(); + current_ts = current->Time(); + current->Dispatch(); + Unref(current); - ++event_mgr.num_events_dispatched; - current = next; - } - } + ++event_mgr.num_events_dispatched; + current = next; + } + } - // Note: we might eventually need a general way to specify things to - // do after draining events. - draining = false; + // Note: we might eventually need a general way to specify things to + // do after draining events. + draining = false; - // Make sure all of the triggers get processed every time the events - // drain. - detail::trigger_mgr->Process(); - } + // Make sure all of the triggers get processed every time the events + // drain. + detail::trigger_mgr->Process(); +} -void EventMgr::Describe(ODesc* d) const - { - int n = 0; - Event* e; - for ( e = head; e; e = e->NextEvent() ) - ++n; +void EventMgr::Describe(ODesc* d) const { + int n = 0; + Event* e; + for ( e = head; e; e = e->NextEvent() ) + ++n; - d->AddCount(n); + d->AddCount(n); - for ( e = head; e; e = e->NextEvent() ) - { - e->Describe(d); - d->NL(); - } - } + for ( e = head; e; e = e->NextEvent() ) { + e->Describe(d); + d->NL(); + } +} -void EventMgr::Process() - { - queue_flare.Extinguish(); +void EventMgr::Process() { + queue_flare.Extinguish(); - // While it semes like the most logical thing to do, we dont want - // to call Drain() as part of this method. It will get called at - // the end of net_run after all of the sources have been processed - // and had the opportunity to spawn new events. - } + // While it semes like the most logical thing to do, we dont want + // to call Drain() as part of this method. It will get called at + // the end of net_run after all of the sources have been processed + // and had the opportunity to spawn new events. +} -void EventMgr::InitPostScript() - { - iosource_mgr->Register(this, true, false); - if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) ) - reporter->FatalError("Failed to register event manager FD with iosource_mgr"); - } +void EventMgr::InitPostScript() { + iosource_mgr->Register(this, true, false); + if ( ! iosource_mgr->RegisterFd(queue_flare.FD(), this) ) + reporter->FatalError("Failed to register event manager FD with iosource_mgr"); +} -void EventMgr::InitPostFork() - { - // Re-initialize the flare, closing and re-opening the underlying - // pipe FDs. This is needed so that each Zeek process in a supervisor - // setup has its own pipe instead of them all sharing a single pipe. - queue_flare = zeek::detail::Flare{}; - } +void EventMgr::InitPostFork() { + // Re-initialize the flare, closing and re-opening the underlying + // pipe FDs. This is needed so that each Zeek process in a supervisor + // setup has its own pipe instead of them all sharing a single pipe. + queue_flare = zeek::detail::Flare{}; +} - } // namespace zeek +} // namespace zeek diff --git a/src/Event.h b/src/Event.h index 75929817c1..2fc9ee0e53 100644 --- a/src/Event.h +++ b/src/Event.h @@ -12,131 +12,124 @@ #include "zeek/analyzer/Analyzer.h" #include "zeek/iosource/IOSource.h" -namespace zeek - { +namespace zeek { -namespace run_state - { +namespace run_state { extern double network_time; - } // namespace run_state +} // namespace run_state class EventMgr; -class Event final : public Obj - { +class Event final : public Obj { public: - Event(const EventHandlerPtr& handler, zeek::Args args, - util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, - Obj* obj = nullptr, double ts = run_state::network_time); + Event(const EventHandlerPtr& handler, zeek::Args args, util::detail::SourceID src = util::detail::SOURCE_LOCAL, + analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time); - void SetNext(Event* n) { next_event = n; } - Event* NextEvent() const { return next_event; } + void SetNext(Event* n) { next_event = n; } + Event* NextEvent() const { return next_event; } - util::detail::SourceID Source() const { return src; } - analyzer::ID Analyzer() const { return aid; } - EventHandlerPtr Handler() const { return handler; } - const zeek::Args& Args() const { return args; } - double Time() const { return ts; } + util::detail::SourceID Source() const { return src; } + analyzer::ID Analyzer() const { return aid; } + EventHandlerPtr Handler() const { return handler; } + const zeek::Args& Args() const { return args; } + double Time() const { return ts; } - void Describe(ODesc* d) const override; + void Describe(ODesc* d) const override; protected: - friend class EventMgr; + friend class EventMgr; - // This method is protected to make sure that everybody goes through - // EventMgr::Dispatch(). - void Dispatch(bool no_remote = false); + // This method is protected to make sure that everybody goes through + // EventMgr::Dispatch(). + void Dispatch(bool no_remote = false); - EventHandlerPtr handler; - zeek::Args args; - util::detail::SourceID src; - analyzer::ID aid; - double ts; - Obj* obj; - Event* next_event; - }; + EventHandlerPtr handler; + zeek::Args args; + util::detail::SourceID src; + analyzer::ID aid; + double ts; + Obj* obj; + Event* next_event; +}; -class EventMgr final : public Obj, public iosource::IOSource - { +class EventMgr final : public Obj, public iosource::IOSource { public: - EventMgr(); - ~EventMgr() override; + EventMgr(); + ~EventMgr() override; - /** - * Adds an event to the queue. If no handler is found for the event - * when later going to call it, nothing happens except for having - * wasted a bit of time/resources, so callers may want to first check - * if any handler/consumer exists before enqueuing an event. - * @param h reference to the event handler to later call. - * @param vl the argument list to the event handler call. - * @param src indicates the origin of the event (local versus remote). - * @param aid identifies the protocol analyzer generating the event. - * @param obj an arbitrary object to use as a "cookie" or just hold a - * reference to until dispatching the event. - * @param ts timestamp at which the event is intended to be executed - * (defaults to current network time). - */ - void Enqueue(const EventHandlerPtr& h, zeek::Args vl, - util::detail::SourceID src = util::detail::SOURCE_LOCAL, analyzer::ID aid = 0, - Obj* obj = nullptr, double ts = run_state::network_time); + /** + * Adds an event to the queue. If no handler is found for the event + * when later going to call it, nothing happens except for having + * wasted a bit of time/resources, so callers may want to first check + * if any handler/consumer exists before enqueuing an event. + * @param h reference to the event handler to later call. + * @param vl the argument list to the event handler call. + * @param src indicates the origin of the event (local versus remote). + * @param aid identifies the protocol analyzer generating the event. + * @param obj an arbitrary object to use as a "cookie" or just hold a + * reference to until dispatching the event. + * @param ts timestamp at which the event is intended to be executed + * (defaults to current network time). + */ + void Enqueue(const EventHandlerPtr& h, zeek::Args vl, util::detail::SourceID src = util::detail::SOURCE_LOCAL, + analyzer::ID aid = 0, Obj* obj = nullptr, double ts = run_state::network_time); - /** - * A version of Enqueue() taking a variable number of arguments. - */ - template - std::enable_if_t>, ValPtr>> - Enqueue(const EventHandlerPtr& h, Args&&... args) - { - return Enqueue(h, zeek::Args{std::forward(args)...}); - } + /** + * A version of Enqueue() taking a variable number of arguments. + */ + template + std::enable_if_t>, ValPtr>> Enqueue( + const EventHandlerPtr& h, Args&&... args) { + return Enqueue(h, zeek::Args{std::forward(args)...}); + } - void Dispatch(Event* event, bool no_remote = false); + void Dispatch(Event* event, bool no_remote = false); - void Drain(); - bool IsDraining() const { return draining; } + void Drain(); + bool IsDraining() const { return draining; } - bool HasEvents() const { return head != nullptr; } + bool HasEvents() const { return head != nullptr; } - // Returns the source ID of last raised event. - util::detail::SourceID CurrentSource() const { return current_src; } + // Returns the source ID of last raised event. + util::detail::SourceID CurrentSource() const { return current_src; } - // Returns the ID of the analyzer which raised the last event, or 0 if - // non-analyzer event. - analyzer::ID CurrentAnalyzer() const { return current_aid; } + // Returns the ID of the analyzer which raised the last event, or 0 if + // non-analyzer event. + analyzer::ID CurrentAnalyzer() const { return current_aid; } - // Returns the timestamp of the last raised event. The timestamp reflects the network time - // the event was intended to be executed. For scheduled events, this is the time the event - // was scheduled to. For any other event, this is the time when the event was created. - double CurrentEventTime() const { return current_ts; } + // Returns the timestamp of the last raised event. The timestamp reflects the network time + // the event was intended to be executed. For scheduled events, this is the time the event + // was scheduled to. For any other event, this is the time when the event was created. + double CurrentEventTime() const { return current_ts; } - int Size() const { return num_events_queued - num_events_dispatched; } + int Size() const { return num_events_queued - num_events_dispatched; } - void Describe(ODesc* d) const override; + void Describe(ODesc* d) const override; - double GetNextTimeout() override { return -1; } - void Process() override; - const char* Tag() override { return "EventManager"; } - void InitPostScript(); + double GetNextTimeout() override { return -1; } + void Process() override; + const char* Tag() override { return "EventManager"; } + void InitPostScript(); - // Initialization to be done after a fork() happened. - void InitPostFork(); + // Initialization to be done after a fork() happened. + void InitPostFork(); - uint64_t num_events_queued = 0; - uint64_t num_events_dispatched = 0; + uint64_t num_events_queued = 0; + uint64_t num_events_dispatched = 0; protected: - void QueueEvent(Event* event); + void QueueEvent(Event* event); - Event* head; - Event* tail; - util::detail::SourceID current_src; - analyzer::ID current_aid; - double current_ts; - RecordVal* src_val; - bool draining; - detail::Flare queue_flare; - }; + Event* head; + Event* tail; + util::detail::SourceID current_src; + analyzer::ID current_aid; + double current_ts; + RecordVal* src_val; + bool draining; + detail::Flare queue_flare; +}; extern EventMgr event_mgr; - } // namespace zeek +} // namespace zeek diff --git a/src/EventHandler.cc b/src/EventHandler.cc index 79f12d32cb..c31a601713 100644 --- a/src/EventHandler.cc +++ b/src/EventHandler.cc @@ -11,127 +11,108 @@ #include "zeek/broker/Manager.h" #include "zeek/telemetry/Manager.h" -namespace zeek - { +namespace zeek { -EventHandler::EventHandler(std::string arg_name) - { - name = std::move(arg_name); - used = false; - error_handler = false; - enabled = true; - generate_always = false; - } +EventHandler::EventHandler(std::string arg_name) { + name = std::move(arg_name); + used = false; + error_handler = false; + enabled = true; + generate_always = false; +} -EventHandler::operator bool() const - { - return enabled && - ((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty()); - } +EventHandler::operator bool() const { + return enabled && ((local && local->HasEnabledBodies()) || generate_always || ! auto_publish.empty()); +} -const FuncTypePtr& EventHandler::GetType(bool check_export) - { - if ( type ) - return type; +const FuncTypePtr& EventHandler::GetType(bool check_export) { + if ( type ) + return type; - const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, - check_export); + const auto& id = detail::lookup_ID(name.data(), detail::current_module.c_str(), false, false, check_export); - if ( ! id ) - return FuncType::nil; + if ( ! id ) + return FuncType::nil; - if ( id->GetType()->Tag() != TYPE_FUNC ) - return FuncType::nil; + if ( id->GetType()->Tag() != TYPE_FUNC ) + return FuncType::nil; - type = id->GetType(); - return type; - } + type = id->GetType(); + return type; +} -void EventHandler::SetFunc(FuncPtr f) - { - local = std::move(f); - } +void EventHandler::SetFunc(FuncPtr f) { local = std::move(f); } -void EventHandler::Call(Args* vl, bool no_remote, double ts) - { - if ( ! call_count ) - { - static auto eh_invocations_family = telemetry_mgr->CounterFamily( - "zeek", "event-handler-invocations", {"name"}, - "Number of times the given event handler was called", "1", true); +void EventHandler::Call(Args* vl, bool no_remote, double ts) { + if ( ! call_count ) { + static auto eh_invocations_family = + telemetry_mgr->CounterFamily("zeek", "event-handler-invocations", {"name"}, + "Number of times the given event handler was called", "1", true); - call_count = eh_invocations_family.GetOrAdd({{"name", name}}); - } + call_count = eh_invocations_family.GetOrAdd({{"name", name}}); + } - call_count->Inc(); + call_count->Inc(); - if ( new_event ) - NewEvent(vl); + if ( new_event ) + NewEvent(vl); - if ( ! no_remote ) - { - if ( ! auto_publish.empty() ) - { - // Send event in form [name, xs...] where xs represent the arguments. - broker::vector xs; - xs.reserve(vl->size()); - bool valid_args = true; + if ( ! no_remote ) { + if ( ! auto_publish.empty() ) { + // Send event in form [name, xs...] where xs represent the arguments. + broker::vector xs; + xs.reserve(vl->size()); + bool valid_args = true; - for ( auto i = 0u; i < vl->size(); ++i ) - { - auto opt_data = Broker::detail::val_to_data((*vl)[i].get()); + for ( auto i = 0u; i < vl->size(); ++i ) { + auto opt_data = Broker::detail::val_to_data((*vl)[i].get()); - if ( opt_data ) - xs.emplace_back(std::move(*opt_data)); - else - { - valid_args = false; - auto_publish.clear(); - reporter->Error("failed auto-remote event '%s', disabled", Name()); - break; - } - } + if ( opt_data ) + xs.emplace_back(std::move(*opt_data)); + else { + valid_args = false; + auto_publish.clear(); + reporter->Error("failed auto-remote event '%s', disabled", Name()); + break; + } + } - if ( valid_args ) - { - for ( auto it = auto_publish.begin();; ) - { - const auto& topic = *it; - ++it; + if ( valid_args ) { + for ( auto it = auto_publish.begin();; ) { + const auto& topic = *it; + ++it; - if ( it != auto_publish.end() ) - broker_mgr->PublishEvent(topic, Name(), xs, ts); - else - { - broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts); - break; - } - } - } - } - } + if ( it != auto_publish.end() ) + broker_mgr->PublishEvent(topic, Name(), xs, ts); + else { + broker_mgr->PublishEvent(topic, Name(), std::move(xs), ts); + break; + } + } + } + } + } - if ( local ) - // No try/catch here; we pass exceptions upstream. - local->Invoke(vl); - } + if ( local ) + // No try/catch here; we pass exceptions upstream. + local->Invoke(vl); +} -void EventHandler::NewEvent(Args* vl) - { - if ( ! new_event ) - return; +void EventHandler::NewEvent(Args* vl) { + if ( ! new_event ) + return; - if ( this == new_event.Ptr() ) - // new_event() is the one event we don't want to report. - return; + if ( this == new_event.Ptr() ) + // new_event() is the one event we don't want to report. + return; - auto vargs = MakeCallArgumentVector(*vl, GetType()->Params()); + auto vargs = MakeCallArgumentVector(*vl, GetType()->Params()); - auto ev = new Event(new_event, { - make_intrusive(name), - std::move(vargs), - }); - event_mgr.Dispatch(ev); - } + auto ev = new Event(new_event, { + make_intrusive(name), + std::move(vargs), + }); + event_mgr.Dispatch(ev); +} - } // namespace zeek +} // namespace zeek diff --git a/src/EventHandler.h b/src/EventHandler.h index 89d5d705fc..d99e3acaad 100644 --- a/src/EventHandler.h +++ b/src/EventHandler.h @@ -11,104 +11,95 @@ #include "zeek/ZeekList.h" #include "zeek/telemetry/Counter.h" -namespace zeek - { +namespace zeek { -namespace run_state - { +namespace run_state { extern double network_time; - } // namespace run_state +} // namespace run_state class Func; using FuncPtr = IntrusivePtr; -class EventHandler - { +class EventHandler { public: - explicit EventHandler(std::string name); + explicit EventHandler(std::string name); - const char* Name() const { return name.data(); } + const char* Name() const { return name.data(); } - const FuncPtr& GetFunc() const { return local; } + const FuncPtr& GetFunc() const { return local; } - const FuncTypePtr& GetType(bool check_export = true); + const FuncTypePtr& GetType(bool check_export = true); - void SetFunc(FuncPtr f); + void SetFunc(FuncPtr f); - void AutoPublish(std::string topic) { auto_publish.insert(std::move(topic)); } + void AutoPublish(std::string topic) { auto_publish.insert(std::move(topic)); } - void AutoUnpublish(const std::string& topic) { auto_publish.erase(topic); } + void AutoUnpublish(const std::string& topic) { auto_publish.erase(topic); } - void Call(zeek::Args* vl, bool no_remote = false, double ts = run_state::network_time); + void Call(zeek::Args* vl, bool no_remote = false, double ts = run_state::network_time); - // Returns true if there is at least one local or remote handler. - explicit operator bool() const; + // Returns true if there is at least one local or remote handler. + explicit operator bool() const; - void SetUsed() { used = true; } - bool Used() const { return used; } + void SetUsed() { used = true; } + bool Used() const { return used; } - // Handlers marked as error handlers will not be called recursively to - // avoid infinite loops if they trigger a similar error themselves. - void SetErrorHandler() { error_handler = true; } - bool ErrorHandler() const { return error_handler; } + // Handlers marked as error handlers will not be called recursively to + // avoid infinite loops if they trigger a similar error themselves. + void SetErrorHandler() { error_handler = true; } + bool ErrorHandler() const { return error_handler; } - void SetEnable(bool arg_enable) { enabled = arg_enable; } + void SetEnable(bool arg_enable) { enabled = arg_enable; } - // Flags the event as interesting even if there is no body defined. In - // particular, this will then still pass the event on to plugins. - void SetGenerateAlways(bool arg_generate_always = true) - { - generate_always = arg_generate_always; - } - bool GenerateAlways() const { return generate_always; } + // Flags the event as interesting even if there is no body defined. In + // particular, this will then still pass the event on to plugins. + void SetGenerateAlways(bool arg_generate_always = true) { generate_always = arg_generate_always; } + bool GenerateAlways() const { return generate_always; } - uint64_t CallCount() const { return call_count ? call_count->Value() : 0; } + uint64_t CallCount() const { return call_count ? call_count->Value() : 0; } private: - void NewEvent(zeek::Args* vl); // Raise new_event() meta event. + void NewEvent(zeek::Args* vl); // Raise new_event() meta event. - std::string name; - FuncPtr local; - FuncTypePtr type; - bool used; // this handler is indeed used somewhere - bool enabled; - bool error_handler; // this handler reports error messages. - bool generate_always; + std::string name; + FuncPtr local; + FuncTypePtr type; + bool used; // this handler is indeed used somewhere + bool enabled; + bool error_handler; // this handler reports error messages. + bool generate_always; - // Initialize this lazy, so we don't expose metrics for 0 values. - std::optional call_count; + // Initialize this lazy, so we don't expose metrics for 0 values. + std::optional call_count; - std::unordered_set auto_publish; - }; + std::unordered_set auto_publish; +}; // Encapsulates a ptr to an event handler to overload the boolean operator. -class EventHandlerPtr - { +class EventHandlerPtr { public: - EventHandlerPtr(EventHandler* p = nullptr) { handler = p; } - EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; } + EventHandlerPtr(EventHandler* p = nullptr) { handler = p; } + EventHandlerPtr(const EventHandlerPtr& h) { handler = h.handler; } - const EventHandlerPtr& operator=(EventHandler* p) - { - handler = p; - return *this; - } - const EventHandlerPtr& operator=(const EventHandlerPtr& h) - { - handler = h.handler; - return *this; - } + const EventHandlerPtr& operator=(EventHandler* p) { + handler = p; + return *this; + } + const EventHandlerPtr& operator=(const EventHandlerPtr& h) { + handler = h.handler; + return *this; + } - bool operator==(const EventHandlerPtr& h) const { return handler == h.handler; } + bool operator==(const EventHandlerPtr& h) const { return handler == h.handler; } - EventHandler* Ptr() { return handler; } + EventHandler* Ptr() { return handler; } - explicit operator bool() const { return handler && *handler; } - EventHandler* operator->() { return handler; } - const EventHandler* operator->() const { return handler; } + explicit operator bool() const { return handler && *handler; } + EventHandler* operator->() { return handler; } + const EventHandler* operator->() const { return handler; } private: - EventHandler* handler; - }; + EventHandler* handler; +}; - } // namespace zeek +} // namespace zeek diff --git a/src/EventRegistry.cc b/src/EventRegistry.cc index 1b4cb963b9..0cddef6dff 100644 --- a/src/EventRegistry.cc +++ b/src/EventRegistry.cc @@ -7,167 +7,144 @@ #include "zeek/RE.h" #include "zeek/Reporter.h" -namespace zeek - { +namespace zeek { EventRegistry::EventRegistry() = default; EventRegistry::~EventRegistry() noexcept = default; -EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) - { - // If there already is an entry in the registry, we have a - // local handler on the script layer. - EventHandler* h = event_registry->Lookup(name); +EventHandlerPtr EventRegistry::Register(std::string_view name, bool is_from_script) { + // If there already is an entry in the registry, we have a + // local handler on the script layer. + EventHandler* h = event_registry->Lookup(name); - if ( h ) - { - if ( ! is_from_script ) - not_only_from_script.insert(std::string(name)); + if ( h ) { + if ( ! is_from_script ) + not_only_from_script.insert(std::string(name)); - h->SetUsed(); - return h; - } + h->SetUsed(); + return h; + } - h = new EventHandler(std::string(name)); - event_registry->Register(h, is_from_script); + h = new EventHandler(std::string(name)); + event_registry->Register(h, is_from_script); - h->SetUsed(); + h->SetUsed(); - return h; - } + return h; +} -void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) - { - std::string name = handler->Name(); +void EventRegistry::Register(EventHandlerPtr handler, bool is_from_script) { + std::string name = handler->Name(); - handlers[name] = std::unique_ptr(handler.Ptr()); + handlers[name] = std::unique_ptr(handler.Ptr()); - if ( ! is_from_script ) - not_only_from_script.insert(name); - } + if ( ! is_from_script ) + not_only_from_script.insert(name); +} -EventHandler* EventRegistry::Lookup(std::string_view name) - { - auto it = handlers.find(name); - if ( it != handlers.end() ) - return it->second.get(); +EventHandler* EventRegistry::Lookup(std::string_view name) { + auto it = handlers.find(name); + if ( it != handlers.end() ) + return it->second.get(); - return nullptr; - } + return nullptr; +} -bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) - { - return not_only_from_script.count(std::string(name)) > 0; - } +bool EventRegistry::NotOnlyRegisteredFromScript(std::string_view name) { + return not_only_from_script.count(std::string(name)) > 0; +} -EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) - { - string_list names; +EventRegistry::string_list EventRegistry::Match(RE_Matcher* pattern) { + string_list names; - for ( const auto& entry : handlers ) - { - EventHandler* v = entry.second.get(); - if ( v->GetFunc() && pattern->MatchExactly(v->Name()) ) - names.push_back(entry.first); - } + for ( const auto& entry : handlers ) { + EventHandler* v = entry.second.get(); + if ( v->GetFunc() && pattern->MatchExactly(v->Name()) ) + names.push_back(entry.first); + } - return names; - } + return names; +} -EventRegistry::string_list EventRegistry::UnusedHandlers() - { - string_list names; +EventRegistry::string_list EventRegistry::UnusedHandlers() { + string_list names; - for ( const auto& entry : handlers ) - { - EventHandler* v = entry.second.get(); - if ( v->GetFunc() && ! v->Used() ) - names.push_back(entry.first); - } + for ( const auto& entry : handlers ) { + EventHandler* v = entry.second.get(); + if ( v->GetFunc() && ! v->Used() ) + names.push_back(entry.first); + } - return names; - } + return names; +} -EventRegistry::string_list EventRegistry::UsedHandlers() - { - string_list names; +EventRegistry::string_list EventRegistry::UsedHandlers() { + string_list names; - for ( const auto& entry : handlers ) - { - EventHandler* v = entry.second.get(); - if ( v->GetFunc() && v->Used() ) - names.push_back(entry.first); - } + for ( const auto& entry : handlers ) { + EventHandler* v = entry.second.get(); + if ( v->GetFunc() && v->Used() ) + names.push_back(entry.first); + } - return names; - } + return names; +} -EventRegistry::string_list EventRegistry::AllHandlers() - { - string_list names; +EventRegistry::string_list EventRegistry::AllHandlers() { + string_list names; - for ( const auto& entry : handlers ) - { - names.push_back(entry.first); - } + for ( const auto& entry : handlers ) { + names.push_back(entry.first); + } - return names; - } + return names; +} -void EventRegistry::PrintDebug() - { - for ( const auto& entry : handlers ) - { - EventHandler* v = entry.second.get(); - fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), - v->GetFunc() ? "local" : "no", *v ? "active" : "not active"); - } - } +void EventRegistry::PrintDebug() { + for ( const auto& entry : handlers ) { + EventHandler* v = entry.second.get(); + fprintf(stderr, "Registered event %s (%s handler / %s)\n", v->Name(), v->GetFunc() ? "local" : "no", + *v ? "active" : "not active"); + } +} -void EventRegistry::SetErrorHandler(std::string_view name) - { - EventHandler* eh = Lookup(name); +void EventRegistry::SetErrorHandler(std::string_view name) { + EventHandler* eh = Lookup(name); - if ( eh ) - { - eh->SetErrorHandler(); - return; - } + if ( eh ) { + eh->SetErrorHandler(); + return; + } - reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", - std::string(name).c_str()); - } + reporter->InternalWarning("unknown event handler '%s' in SetErrorHandler()", std::string(name).c_str()); +} -void EventRegistry::ActivateAllHandlers() - { - auto event_names = AllHandlers(); - for ( const auto& name : event_names ) - { - if ( auto event = Lookup(name) ) - event->SetGenerateAlways(); - } - } +void EventRegistry::ActivateAllHandlers() { + auto event_names = AllHandlers(); + for ( const auto& name : event_names ) { + if ( auto event = Lookup(name) ) + event->SetGenerateAlways(); + } +} -EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) - { - auto key = std::pair{kind, std::string{name}}; - if ( const auto& it = event_groups.find(key); it != event_groups.end() ) - return it->second; +EventGroupPtr EventRegistry::RegisterGroup(EventGroupKind kind, std::string_view name) { + auto key = std::pair{kind, std::string{name}}; + if ( const auto& it = event_groups.find(key); it != event_groups.end() ) + return it->second; - auto group = std::make_shared(kind, name); - return event_groups.emplace(key, group).first->second; - } + auto group = std::make_shared(kind, name); + return event_groups.emplace(key, group).first->second; +} -EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) - { - auto key = std::pair{kind, std::string{name}}; - if ( const auto& it = event_groups.find(key); it != event_groups.end() ) - return it->second; +EventGroupPtr EventRegistry::LookupGroup(EventGroupKind kind, std::string_view name) { + auto key = std::pair{kind, std::string{name}}; + if ( const auto& it = event_groups.find(key); it != event_groups.end() ) + return it->second; - return nullptr; - } + return nullptr; +} -EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind), name(name) { } +EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind), name(name) {} // Run through all ScriptFunc instances associated with this group and // update their bodies after a group's enable/disable state has changed. @@ -177,50 +154,36 @@ EventGroup::EventGroup(EventGroupKind kind, std::string_view name) : kind(kind), // EventGroup is private friend with Func, so fiddling with the bodies // and private members works and keeps the logic out of Func and away // from the public zeek:: namespace. -void EventGroup::UpdateFuncBodies() - { - static auto is_group_disabled = [](const auto& g) - { - return g->IsDisabled(); - }; +void EventGroup::UpdateFuncBodies() { + static auto is_group_disabled = [](const auto& g) { return g->IsDisabled(); }; - for ( auto& func : funcs ) - { - for ( auto& b : func->bodies ) - b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled); + for ( auto& func : funcs ) { + for ( auto& b : func->bodies ) + b.disabled = std::any_of(b.groups.cbegin(), b.groups.cend(), is_group_disabled); - static auto is_body_enabled = [](const auto& b) - { - return ! b.disabled; - }; - func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(), - is_body_enabled); - } - } + static auto is_body_enabled = [](const auto& b) { return ! b.disabled; }; + func->has_enabled_bodies = std::any_of(func->bodies.cbegin(), func->bodies.cend(), is_body_enabled); + } +} -void EventGroup::Enable() - { - if ( enabled ) - return; +void EventGroup::Enable() { + if ( enabled ) + return; - enabled = true; + enabled = true; - UpdateFuncBodies(); - } + UpdateFuncBodies(); +} -void EventGroup::Disable() - { - if ( ! enabled ) - return; +void EventGroup::Disable() { + if ( ! enabled ) + return; - enabled = false; + enabled = false; - UpdateFuncBodies(); - } + UpdateFuncBodies(); +} -void EventGroup::AddFunc(detail::ScriptFuncPtr f) - { - funcs.insert(f); - } +void EventGroup::AddFunc(detail::ScriptFuncPtr f) { funcs.insert(f); } - } // namespace zeek +} // namespace zeek diff --git a/src/EventRegistry.h b/src/EventRegistry.h index b1189f5a5e..f0740e9f17 100644 --- a/src/EventRegistry.h +++ b/src/EventRegistry.h @@ -13,15 +13,13 @@ #include "zeek/IntrusivePtr.h" -namespace zeek - { +namespace zeek { // The different kinds of event groups that exist. -enum class EventGroupKind - { - Attribute, - Module, - }; +enum class EventGroupKind { + Attribute, + Module, +}; class EventGroup; class EventHandler; @@ -30,89 +28,86 @@ class RE_Matcher; using EventGroupPtr = std::shared_ptr; -namespace detail - { +namespace detail { class ScriptFunc; using ScriptFuncPtr = zeek::IntrusivePtr; - } +} // namespace detail // The registry keeps track of all events that we provide or handle. -class EventRegistry final - { +class EventRegistry final { public: - EventRegistry(); - ~EventRegistry() noexcept; + EventRegistry(); + ~EventRegistry() noexcept; - /** - * Performs a lookup for an existing event handler and returns it - * if one exists, or else creates one, registers it, and returns it. - * @param name The name of the event handler to lookup/register. - * @param name Whether the registration is coming from a script element. - * @return The event handler. - */ - EventHandlerPtr Register(std::string_view name, bool is_from_script = false); + /** + * Performs a lookup for an existing event handler and returns it + * if one exists, or else creates one, registers it, and returns it. + * @param name The name of the event handler to lookup/register. + * @param name Whether the registration is coming from a script element. + * @return The event handler. + */ + EventHandlerPtr Register(std::string_view name, bool is_from_script = false); - void Register(EventHandlerPtr handler, bool is_from_script = false); + void Register(EventHandlerPtr handler, bool is_from_script = false); - // Return nil if unknown. - EventHandler* Lookup(std::string_view name); + // Return nil if unknown. + EventHandler* Lookup(std::string_view name); - // True if the given event handler (1) exists, and (2) was registered - // in a non-script context (even if perhaps also registered in a script - // context). - bool NotOnlyRegisteredFromScript(std::string_view name); + // True if the given event handler (1) exists, and (2) was registered + // in a non-script context (even if perhaps also registered in a script + // context). + bool NotOnlyRegisteredFromScript(std::string_view name); - // Returns a list of all local handlers that match the given pattern. - // Passes ownership of list. - using string_list = std::vector; - string_list Match(RE_Matcher* pattern); + // Returns a list of all local handlers that match the given pattern. + // Passes ownership of list. + using string_list = std::vector; + string_list Match(RE_Matcher* pattern); - // Marks a handler as handling errors. Error handler will not be called - // recursively to avoid infinite loops in case they trigger an error - // themselves. - void SetErrorHandler(std::string_view name); + // Marks a handler as handling errors. Error handler will not be called + // recursively to avoid infinite loops in case they trigger an error + // themselves. + void SetErrorHandler(std::string_view name); - string_list UnusedHandlers(); - string_list UsedHandlers(); - string_list AllHandlers(); + string_list UnusedHandlers(); + string_list UsedHandlers(); + string_list AllHandlers(); - void PrintDebug(); + void PrintDebug(); - /** - * Marks all event handlers as active. - * - * By default, zeek does not generate (raise) events that have not handled by - * any scripts. This means that these events will be invisible to a lot of other - * event handlers - and will not raise :zeek:id:`new_event`. Calling this - * function will cause all event handlers to be raised. This is likely only - * useful for debugging and fuzzing, and likely causes reduced performance. - */ - void ActivateAllHandlers(); + /** + * Marks all event handlers as active. + * + * By default, zeek does not generate (raise) events that have not handled by + * any scripts. This means that these events will be invisible to a lot of other + * event handlers - and will not raise :zeek:id:`new_event`. Calling this + * function will cause all event handlers to be raised. This is likely only + * useful for debugging and fuzzing, and likely causes reduced performance. + */ + void ActivateAllHandlers(); - /** - * Lookup or register a new event group. - * - * @return Pointer to the group. - */ - EventGroupPtr RegisterGroup(EventGroupKind kind, std::string_view name); + /** + * Lookup or register a new event group. + * + * @return Pointer to the group. + */ + EventGroupPtr RegisterGroup(EventGroupKind kind, std::string_view name); - /** - * Lookup an event group. - * - * @return Pointer to the group or nil if the group does not exist. - */ - EventGroupPtr LookupGroup(EventGroupKind kind, std::string_view name); + /** + * Lookup an event group. + * + * @return Pointer to the group or nil if the group does not exist. + */ + EventGroupPtr LookupGroup(EventGroupKind kind, std::string_view name); private: - std::map, std::less<>> handlers; - // Tracks whether a given event handler was registered in a - // non-script context. - std::unordered_set not_only_from_script; + std::map, std::less<>> handlers; + // Tracks whether a given event handler was registered in a + // non-script context. + std::unordered_set not_only_from_script; - // Map event groups identified by kind and name to their instances. - std::map, std::shared_ptr, std::less<>> - event_groups; - }; + // Map event groups identified by kind and name to their instances. + std::map, std::shared_ptr, std::less<>> event_groups; +}; /** * Event group. @@ -135,45 +130,44 @@ private: * bodies of the tracked ScriptFuncs and updates them to reflect the current * group state. */ -class EventGroup final - { +class EventGroup final { public: - EventGroup(EventGroupKind kind, std::string_view name); - ~EventGroup() noexcept = default; - EventGroup(const EventGroup& g) = delete; - EventGroup& operator=(const EventGroup&) = delete; + EventGroup(EventGroupKind kind, std::string_view name); + ~EventGroup() noexcept = default; + EventGroup(const EventGroup& g) = delete; + EventGroup& operator=(const EventGroup&) = delete; - /** - * Enable this event group and update all event handlers associated with it. - */ - void Enable(); + /** + * Enable this event group and update all event handlers associated with it. + */ + void Enable(); - /** - * Disable this event group and update all event handlers associated with it. - */ - void Disable(); + /** + * Disable this event group and update all event handlers associated with it. + */ + void Disable(); - /** - * @return True if this group is disabled else false. - */ - bool IsDisabled() { return ! enabled; } + /** + * @return True if this group is disabled else false. + */ + bool IsDisabled() { return ! enabled; } - /** - * Add a function to this group that may contain matching bodies. - * - * @param f Pointer to the function to track. - */ - void AddFunc(detail::ScriptFuncPtr f); + /** + * Add a function to this group that may contain matching bodies. + * + * @param f Pointer to the function to track. + */ + void AddFunc(detail::ScriptFuncPtr f); private: - void UpdateFuncBodies(); + void UpdateFuncBodies(); - EventGroupKind kind; - std::string name; - bool enabled = true; - std::unordered_set funcs; - }; + EventGroupKind kind; + std::string name; + bool enabled = true; + std::unordered_set funcs; +}; extern EventRegistry* event_registry; - } // namespace zeek +} // namespace zeek diff --git a/src/EventTrace.cc b/src/EventTrace.cc index d313f7c3ea..e962636945 100644 --- a/src/EventTrace.cc +++ b/src/EventTrace.cc @@ -11,1208 +11,1024 @@ #include "zeek/Reporter.h" #include "zeek/ZeekString.h" -namespace zeek::detail - { +namespace zeek::detail { std::unique_ptr etm; // Helper function for generating a correct script-level representation // of a string constant. -static std::string escape_string(const u_char* b, int len) - { - std::string res = "\""; - - for ( int i = 0; i < len; ++i ) - { - unsigned char c = b[i]; - - switch ( c ) - { - case '\a': - res += "\\a"; - break; - case '\b': - res += "\\b"; - break; - case '\f': - res += "\\f"; - break; - case '\n': - res += "\\n"; - break; - case '\r': - res += "\\r"; - break; - case '\t': - res += "\\t"; - break; - case '\v': - res += "\\v"; - break; - - case '\\': - res += "\\\\"; - break; - case '"': - res += "\\\""; - break; - - default: - if ( isprint(c) ) - res += c; - else - { - char buf[8192]; - snprintf(buf, sizeof buf, "%03o", c); - res += "\\"; - res += buf; - } - break; - } - } - - return res + "\""; - } - -ValTrace::ValTrace(const ValPtr& _v) : v(_v) - { - t = v->GetType(); - - switch ( t->Tag() ) - { - case TYPE_LIST: - TraceList(cast_intrusive(v)); - break; - - case TYPE_RECORD: - TraceRecord(cast_intrusive(v)); - break; - - case TYPE_TABLE: - TraceTable(cast_intrusive(v)); - break; - - case TYPE_VECTOR: - TraceVector(cast_intrusive(v)); - break; - - default: - break; - } - } - -bool ValTrace::operator==(const ValTrace& vt) const - { - auto& vt_v = vt.GetVal(); - if ( vt_v == v ) - return true; - - auto tag = t->Tag(); - - if ( vt.GetType()->Tag() != tag ) - return false; - - switch ( tag ) - { - case TYPE_BOOL: - case TYPE_INT: - case TYPE_ENUM: - return v->AsInt() == vt_v->AsInt(); - - case TYPE_COUNT: - case TYPE_PORT: - return v->AsCount() == vt_v->AsCount(); - - case TYPE_DOUBLE: - case TYPE_INTERVAL: - case TYPE_TIME: - return v->AsDouble() == vt_v->AsDouble(); - - case TYPE_STRING: - return (*v->AsString()) == (*vt_v->AsString()); - - case TYPE_ADDR: - return v->AsAddr() == vt_v->AsAddr(); - - case TYPE_SUBNET: - return v->AsSubNet() == vt_v->AsSubNet(); - - case TYPE_FUNC: - return v->AsFile() == vt_v->AsFile(); - - case TYPE_FILE: - return v->AsFile() == vt_v->AsFile(); - - case TYPE_PATTERN: - return v->AsPattern() == vt_v->AsPattern(); - - case TYPE_ANY: - return v->AsSubNet() == vt_v->AsSubNet(); - - case TYPE_TYPE: - return v->AsType() == vt_v->AsType(); - - case TYPE_OPAQUE: - return false; // needs pointer equivalence - - case TYPE_LIST: - return SameList(vt); - - case TYPE_RECORD: - return SameRecord(vt); - - case TYPE_TABLE: - return SameTable(vt); - - case TYPE_VECTOR: - return SameVector(vt); - - default: - reporter->InternalError("bad type in ValTrace::operator=="); - } - } - -void ValTrace::ComputeDelta(const ValTrace* prev, DeltaVector& deltas) const - { - auto tag = t->Tag(); - - if ( prev ) - { - ASSERT(prev->GetType()->Tag() == tag); - - auto& prev_v = prev->GetVal(); - - if ( prev_v != v ) - { - if ( *this != *prev ) - deltas.emplace_back(std::make_unique(this, v)); - return; - } - } - - switch ( tag ) - { - case TYPE_BOOL: - case TYPE_INT: - case TYPE_ENUM: - case TYPE_COUNT: - case TYPE_PORT: - case TYPE_DOUBLE: - case TYPE_INTERVAL: - case TYPE_TIME: - case TYPE_STRING: - case TYPE_ADDR: - case TYPE_SUBNET: - case TYPE_FUNC: - case TYPE_PATTERN: - case TYPE_TYPE: - // These don't change in place. No need to create - // them as stand-alone variables, since we can just - // use the constant representation instead. - break; - - case TYPE_ANY: - case TYPE_FILE: - case TYPE_OPAQUE: - // If we have a previous instance, we can ignore this - // one, because we know it's equivalent (due to the - // test at the beginning of this method), and it's - // not meaningful to recurse inside it looking for - // interior changes. - if ( ! prev ) - deltas.emplace_back(std::make_unique(this)); - break; - - case TYPE_LIST: - // We shouldn't see these exposed directly, as they're - // not manipulable at script-level. An exception - // might be for "any" types that are then decomposed - // via compound assignment; for now, we don't support - // those. - reporter->InternalError("list type seen in ValTrace::ComputeDelta"); - break; - - case TYPE_RECORD: - if ( prev ) - ComputeRecordDelta(prev, deltas); - else - deltas.emplace_back(std::make_unique(this)); - break; - - case TYPE_TABLE: - if ( prev ) - ComputeTableDelta(prev, deltas); - - else if ( GetType()->AsTableType()->IsUnspecifiedTable() ) - // For unspecified values, we generate them - // as empty constructors, because we don't - // know their yield type and thus can't - // create variables corresponding to them. - break; - - else if ( t->Yield() ) - deltas.emplace_back(std::make_unique(this)); - else - deltas.emplace_back(std::make_unique(this)); - break; - - case TYPE_VECTOR: - if ( prev ) - ComputeVectorDelta(prev, deltas); - - else if ( GetType()->AsVectorType()->IsUnspecifiedVector() ) - // See above for empty tables/sets. - break; - - else - deltas.emplace_back(std::make_unique(this)); - break; - - default: - reporter->InternalError("bad type in ValTrace::ComputeDelta"); - } - } - -void ValTrace::TraceList(const ListValPtr& lv) - { - auto vals = lv->Vals(); - for ( auto& v : vals ) - elems.emplace_back(std::make_shared(v)); - } - -void ValTrace::TraceRecord(const RecordValPtr& rv) - { - auto n = rv->NumFields(); - auto rt = rv->GetType(); - - for ( auto i = 0U; i < n; ++i ) - { - auto f = rv->RawOptField(i); - if ( f ) - { - auto val = f->ToVal(rt->GetFieldType(i)); - elems.emplace_back(std::make_shared(val)); - } - else - elems.emplace_back(nullptr); - } - } - -void ValTrace::TraceTable(const TableValPtr& tv) - { - for ( auto& elem : tv->ToMap() ) - { - auto& key = elem.first; - elems.emplace_back(std::make_shared(key)); - - auto& val = elem.second; - if ( val ) - elems2.emplace_back(std::make_shared(val)); - } - } - -void ValTrace::TraceVector(const VectorValPtr& vv) - { - auto& vec = vv->RawVec(); - auto n = vec.size(); - auto& yt = vv->RawYieldType(); - auto& yts = vv->RawYieldTypes(); - - for ( auto i = 0U; i < n; ++i ) - { - auto& elem = vec[i]; - if ( elem ) - { - auto& t = yts ? (*yts)[i] : yt; - auto val = elem->ToVal(t); - elems.emplace_back(std::make_shared(val)); - } - else - elems.emplace_back(nullptr); - } - } - -bool ValTrace::SameList(const ValTrace& vt) const - { - return SameElems(vt); - } - -bool ValTrace::SameRecord(const ValTrace& vt) const - { - return SameElems(vt); - } - -bool ValTrace::SameTable(const ValTrace& vt) const - { - auto& vt_elems = vt.elems; - auto n = elems.size(); - if ( n != vt_elems.size() ) - return false; - - auto& vt_elems2 = vt.elems2; - auto n2 = elems2.size(); - if ( n2 != vt_elems2.size() ) - return false; - - ASSERT(n2 == 0 || n == n2); - - // We accommodate the possibility that keys are out-of-order - // between the two sets of elements. - - // The following is O(N^2), but presumably if tables are somehow - // involved (in fact we can only get here if they're used as - // indices into other tables), then they'll likely be small. - for ( auto i = 0U; i < n; ++i ) - { - auto& elem_i = elems[i]; - - // See if we can find a match for it. If we do, we don't - // have to worry that another entry matched it too, since - // all table/set indices will be distinct. - auto j = 0U; - for ( ; j < n; ++j ) - { - auto& vt_elem_j = vt_elems[j]; - if ( *elem_i == *vt_elem_j ) - break; - } - - if ( j == n ) - // No match for the index. - return false; - - if ( n2 > 0 ) - { - // Need a match for the corresponding yield values. - if ( *elems2[i] != *vt_elems2[j] ) - return false; - } - } - - return true; - } - -bool ValTrace::SameVector(const ValTrace& vt) const - { - return SameElems(vt); - } - -bool ValTrace::SameElems(const ValTrace& vt) const - { - auto& vt_elems = vt.elems; - auto n = elems.size(); - if ( n != vt_elems.size() ) - return false; - - for ( auto i = 0U; i < n; ++i ) - { - auto& trace_i = elems[i]; - auto& vt_trace_i = vt_elems[i]; - - if ( trace_i && vt_trace_i ) - { - if ( *trace_i != *vt_trace_i ) - return false; - } - - else if ( trace_i || vt_trace_i ) - return false; - } - - return true; - } - -bool ValTrace::SameSingleton(const ValTrace& vt) const - { - return ! IsAggr(t) && *this == vt; - } - -void ValTrace::ComputeRecordDelta(const ValTrace* prev, DeltaVector& deltas) const - { - auto& prev_elems = prev->elems; - auto n = elems.size(); - if ( n != prev_elems.size() ) - reporter->InternalError("size inconsistency in ValTrace::ComputeRecordDelta"); - - for ( auto i = 0U; i < n; ++i ) - { - const auto trace_i = elems[i].get(); - const auto prev_trace_i = prev_elems[i].get(); - - if ( trace_i ) - { - if ( prev_trace_i ) - { - auto& v = trace_i->GetVal(); - auto& prev_v = prev_trace_i->GetVal(); - - if ( v == prev_v ) - { - trace_i->ComputeDelta(prev_trace_i, deltas); - continue; - } - - if ( trace_i->SameSingleton(*prev_trace_i) ) - // No further work needed. - continue; - } - - deltas.emplace_back(std::make_unique(this, i, trace_i->GetVal())); - } - - else if ( prev_trace_i ) - deltas.emplace_back(std::make_unique(this, i)); - } - } - -void ValTrace::ComputeTableDelta(const ValTrace* prev, DeltaVector& deltas) const - { - auto& prev_elems = prev->elems; - auto& prev_elems2 = prev->elems2; - - auto n = elems.size(); - auto is_set = elems2.size() == 0; - auto prev_n = prev_elems.size(); - - // We can't compare pointers for the indices because they're - // new objects generated afresh by TableVal::ToMap. So we do - // explicit full comparisons for equality, distinguishing values - // newly added, common to both, or (implicitly) removed. We'll - // then go through the common to check them further. - // - // Our approach is O(N^2), but presumably these tables aren't - // large, and in any case generating event traces is not something - // requiring high performance, so we opt for conceptual simplicity. - - // Track which index values are newly added: - std::set added_indices; - - // Track which entry traces are in common. Indexed by previous - // trace elem index, yielding current trace elem index. - std::map common_entries; - - for ( auto i = 0U; i < n; ++i ) - { - const auto trace_i = elems[i].get(); - - bool common = false; - - for ( auto j = 0U; j < prev_n; ++j ) - { - const auto prev_trace_j = prev_elems[j].get(); - - if ( *trace_i == *prev_trace_j ) - { - common_entries[j] = i; - common = true; - break; - } - } - - if ( ! common ) - { - auto v = trace_i->GetVal(); - - if ( is_set ) - deltas.emplace_back(std::make_unique(this, v)); - else - { - auto yield = elems2[i]->GetVal(); - deltas.emplace_back(std::make_unique(this, v, yield)); - } - - added_indices.insert(v.get()); - } - } - - for ( auto j = 0U; j < prev_n; ++j ) - { - auto common_pair = common_entries.find(j); - if ( common_pair == common_entries.end() ) - { - auto& prev_trace = prev_elems[j]; - auto& v = prev_trace->GetVal(); - deltas.emplace_back(std::make_unique(this, v)); - continue; - } - - if ( is_set ) - continue; - - // If we get here, we're analyzing a table for which there's - // a common index. The remaining question is whether the - // yield has changed. - auto i = common_pair->second; - auto& trace2 = elems2[i]; - const auto prev_trace2 = prev_elems2[j]; - - auto& yield = trace2->GetVal(); - auto& prev_yield = prev_trace2->GetVal(); - - if ( yield == prev_yield ) - // Same yield, look for differences in its sub-elements. - trace2->ComputeDelta(prev_trace2.get(), deltas); - - else if ( ! trace2->SameSingleton(*prev_trace2) ) - deltas.emplace_back( - std::make_unique(this, elems[i]->GetVal(), yield)); - } - } - -void ValTrace::ComputeVectorDelta(const ValTrace* prev, DeltaVector& deltas) const - { - auto& prev_elems = prev->elems; - auto n = elems.size(); - auto prev_n = prev_elems.size(); - - // TODO: The following hasn't been tested for robustness to vector holes. - - if ( n < prev_n ) - { - // The vector shrank in size. Easiest to just build it - // from scratch. - deltas.emplace_back(std::make_unique(this)); - return; - } - - // Look for existing entries that need reassignment. - auto i = 0U; - for ( ; i < prev_n; ++i ) - { - const auto trace_i = elems[i].get(); - const auto prev_trace_i = prev_elems[i].get(); - - auto& elem_i = trace_i->GetVal(); - auto& prev_elem_i = prev_trace_i->GetVal(); - - if ( elem_i == prev_elem_i ) - trace_i->ComputeDelta(prev_trace_i, deltas); - else if ( ! trace_i->SameSingleton(*prev_trace_i) ) - deltas.emplace_back(std::make_unique(this, i, elem_i)); - } - - // Now append any new entries. - for ( ; i < n; ++i ) - { - auto& trace_i = elems[i]; - auto& elem_i = trace_i->GetVal(); - deltas.emplace_back(std::make_unique(this, i, elem_i)); - } - } - -std::string ValDelta::Generate(ValTraceMgr* vtm) const - { - return ""; - } - -std::string DeltaReplaceValue::Generate(ValTraceMgr* vtm) const - { - return std::string(" = ") + vtm->ValName(new_val); - } - -std::string DeltaSetField::Generate(ValTraceMgr* vtm) const - { - auto rt = vt->GetType()->AsRecordType(); - auto f = rt->FieldName(field); - return std::string("$") + f + " = " + vtm->ValName(new_val); - } - -std::string DeltaRemoveField::Generate(ValTraceMgr* vtm) const - { - auto rt = vt->GetType()->AsRecordType(); - auto f = rt->FieldName(field); - return std::string("delete ") + vtm->ValName(vt) + "$" + f; - } - -std::string DeltaRecordCreate::Generate(ValTraceMgr* vtm) const - { - auto rv = cast_intrusive(vt->GetVal()); - auto rt = rv->GetType(); - auto n = rt->NumFields(); - - std::string args; - - for ( auto i = 0; i < n; ++i ) - { - auto v_i = rv->GetField(i); - if ( v_i ) - { - if ( ! args.empty() ) - args += ", "; - - args += std::string("$") + rt->FieldName(i) + "=" + vtm->ValName(v_i); - } - } - - auto name = rt->GetName(); - if ( name.empty() ) - name = "record"; - - return std::string(" = ") + name + "(" + args + ")"; - } - -std::string DeltaSetSetEntry::Generate(ValTraceMgr* vtm) const - { - return std::string("add ") + vtm->ValName(vt) + "[" + vtm->ValName(index) + "]"; - } - -std::string DeltaSetTableEntry::Generate(ValTraceMgr* vtm) const - { - return std::string("[") + vtm->ValName(index) + "] = " + vtm->ValName(new_val); - } - -std::string DeltaRemoveTableEntry::Generate(ValTraceMgr* vtm) const - { - return std::string("delete ") + vtm->ValName(vt) + "[" + vtm->ValName(index) + "]"; - } - -std::string DeltaSetCreate::Generate(ValTraceMgr* vtm) const - { - auto sv = cast_intrusive(vt->GetVal()); - auto members = sv->ToMap(); - - std::string args; - - for ( auto& m : members ) - { - if ( ! args.empty() ) - args += ", "; - - args += vtm->ValName(m.first); - } - - auto name = sv->GetType()->GetName(); - if ( name.empty() ) - name = "set"; - - return std::string(" = ") + name + "(" + args + ")"; - } - -std::string DeltaTableCreate::Generate(ValTraceMgr* vtm) const - { - auto tv = cast_intrusive(vt->GetVal()); - auto members = tv->ToMap(); - - std::string args; - - for ( auto& m : members ) - { - if ( ! args.empty() ) - args += ", "; - - args += std::string("[") + vtm->ValName(m.first) + "] = " + vtm->ValName(m.second); - } - - auto name = tv->GetType()->GetName(); - if ( name.empty() ) - name = "table"; - - return std::string(" = ") + name + "(" + args + ")"; - } - -std::string DeltaVectorSet::Generate(ValTraceMgr* vtm) const - { - return std::string("[") + std::to_string(index) + "] = " + vtm->ValName(elem); - } - -std::string DeltaVectorAppend::Generate(ValTraceMgr* vtm) const - { - return std::string("[") + std::to_string(index) + "] = " + vtm->ValName(elem); - } - -std::string DeltaVectorCreate::Generate(ValTraceMgr* vtm) const - { - auto& elems = vt->GetElems(); - std::string vec; - - for ( auto& e : elems ) - { - if ( vec.size() > 0 ) - vec += ", "; +static std::string escape_string(const u_char* b, int len) { + std::string res = "\""; - vec += vtm->ValName(e->GetVal()); - } - - return std::string(" = vector(") + vec + ")"; - } - -std::string DeltaUnsupportedCreate::Generate(ValTraceMgr* vtm) const - { - return " = UNSUPPORTED " + obj_desc_short(vt->GetVal()->GetType().get()); - } + for ( int i = 0; i < len; ++i ) { + unsigned char c = b[i]; -EventTrace::EventTrace(const ScriptFunc* _ev, double _nt, size_t event_num) : ev(_ev), nt(_nt) - { - auto ev_name = std::regex_replace(ev->Name(), std::regex(":"), "_"); + switch ( c ) { + case '\a': res += "\\a"; break; + case '\b': res += "\\b"; break; + case '\f': res += "\\f"; break; + case '\n': res += "\\n"; break; + case '\r': res += "\\r"; break; + case '\t': res += "\\t"; break; + case '\v': res += "\\v"; break; - name = ev_name + "_" + std::to_string(event_num) + "__et"; - } + case '\\': res += "\\\\"; break; + case '"': res += "\\\""; break; + + default: + if ( isprint(c) ) + res += c; + else { + char buf[8192]; + snprintf(buf, sizeof buf, "%03o", c); + res += "\\"; + res += buf; + } + break; + } + } + + return res + "\""; +} + +ValTrace::ValTrace(const ValPtr& _v) : v(_v) { + t = v->GetType(); + + switch ( t->Tag() ) { + case TYPE_LIST: TraceList(cast_intrusive(v)); break; + + case TYPE_RECORD: TraceRecord(cast_intrusive(v)); break; + + case TYPE_TABLE: TraceTable(cast_intrusive(v)); break; + + case TYPE_VECTOR: TraceVector(cast_intrusive(v)); break; + + default: break; + } +} + +bool ValTrace::operator==(const ValTrace& vt) const { + auto& vt_v = vt.GetVal(); + if ( vt_v == v ) + return true; + + auto tag = t->Tag(); + + if ( vt.GetType()->Tag() != tag ) + return false; + + switch ( tag ) { + case TYPE_BOOL: + case TYPE_INT: + case TYPE_ENUM: return v->AsInt() == vt_v->AsInt(); + + case TYPE_COUNT: + case TYPE_PORT: return v->AsCount() == vt_v->AsCount(); + + case TYPE_DOUBLE: + case TYPE_INTERVAL: + case TYPE_TIME: return v->AsDouble() == vt_v->AsDouble(); + + case TYPE_STRING: return (*v->AsString()) == (*vt_v->AsString()); + + case TYPE_ADDR: return v->AsAddr() == vt_v->AsAddr(); + + case TYPE_SUBNET: return v->AsSubNet() == vt_v->AsSubNet(); + + case TYPE_FUNC: return v->AsFile() == vt_v->AsFile(); + + case TYPE_FILE: return v->AsFile() == vt_v->AsFile(); + + case TYPE_PATTERN: return v->AsPattern() == vt_v->AsPattern(); + + case TYPE_ANY: return v->AsSubNet() == vt_v->AsSubNet(); + + case TYPE_TYPE: return v->AsType() == vt_v->AsType(); + + case TYPE_OPAQUE: return false; // needs pointer equivalence + + case TYPE_LIST: return SameList(vt); + + case TYPE_RECORD: return SameRecord(vt); + + case TYPE_TABLE: return SameTable(vt); + + case TYPE_VECTOR: return SameVector(vt); + + default: reporter->InternalError("bad type in ValTrace::operator=="); + } +} + +void ValTrace::ComputeDelta(const ValTrace* prev, DeltaVector& deltas) const { + auto tag = t->Tag(); + + if ( prev ) { + ASSERT(prev->GetType()->Tag() == tag); + + auto& prev_v = prev->GetVal(); + + if ( prev_v != v ) { + if ( *this != *prev ) + deltas.emplace_back(std::make_unique(this, v)); + return; + } + } + + switch ( tag ) { + case TYPE_BOOL: + case TYPE_INT: + case TYPE_ENUM: + case TYPE_COUNT: + case TYPE_PORT: + case TYPE_DOUBLE: + case TYPE_INTERVAL: + case TYPE_TIME: + case TYPE_STRING: + case TYPE_ADDR: + case TYPE_SUBNET: + case TYPE_FUNC: + case TYPE_PATTERN: + case TYPE_TYPE: + // These don't change in place. No need to create + // them as stand-alone variables, since we can just + // use the constant representation instead. + break; + + case TYPE_ANY: + case TYPE_FILE: + case TYPE_OPAQUE: + // If we have a previous instance, we can ignore this + // one, because we know it's equivalent (due to the + // test at the beginning of this method), and it's + // not meaningful to recurse inside it looking for + // interior changes. + if ( ! prev ) + deltas.emplace_back(std::make_unique(this)); + break; + + case TYPE_LIST: + // We shouldn't see these exposed directly, as they're + // not manipulable at script-level. An exception + // might be for "any" types that are then decomposed + // via compound assignment; for now, we don't support + // those. + reporter->InternalError("list type seen in ValTrace::ComputeDelta"); + break; + + case TYPE_RECORD: + if ( prev ) + ComputeRecordDelta(prev, deltas); + else + deltas.emplace_back(std::make_unique(this)); + break; + + case TYPE_TABLE: + if ( prev ) + ComputeTableDelta(prev, deltas); + + else if ( GetType()->AsTableType()->IsUnspecifiedTable() ) + // For unspecified values, we generate them + // as empty constructors, because we don't + // know their yield type and thus can't + // create variables corresponding to them. + break; + + else if ( t->Yield() ) + deltas.emplace_back(std::make_unique(this)); + else + deltas.emplace_back(std::make_unique(this)); + break; + + case TYPE_VECTOR: + if ( prev ) + ComputeVectorDelta(prev, deltas); + + else if ( GetType()->AsVectorType()->IsUnspecifiedVector() ) + // See above for empty tables/sets. + break; + + else + deltas.emplace_back(std::make_unique(this)); + break; + + default: reporter->InternalError("bad type in ValTrace::ComputeDelta"); + } +} + +void ValTrace::TraceList(const ListValPtr& lv) { + auto vals = lv->Vals(); + for ( auto& v : vals ) + elems.emplace_back(std::make_shared(v)); +} + +void ValTrace::TraceRecord(const RecordValPtr& rv) { + auto n = rv->NumFields(); + auto rt = rv->GetType(); + + for ( auto i = 0U; i < n; ++i ) { + auto f = rv->RawOptField(i); + if ( f ) { + auto val = f->ToVal(rt->GetFieldType(i)); + elems.emplace_back(std::make_shared(val)); + } + else + elems.emplace_back(nullptr); + } +} + +void ValTrace::TraceTable(const TableValPtr& tv) { + for ( auto& elem : tv->ToMap() ) { + auto& key = elem.first; + elems.emplace_back(std::make_shared(key)); + + auto& val = elem.second; + if ( val ) + elems2.emplace_back(std::make_shared(val)); + } +} + +void ValTrace::TraceVector(const VectorValPtr& vv) { + auto& vec = vv->RawVec(); + auto n = vec.size(); + auto& yt = vv->RawYieldType(); + auto& yts = vv->RawYieldTypes(); + + for ( auto i = 0U; i < n; ++i ) { + auto& elem = vec[i]; + if ( elem ) { + auto& t = yts ? (*yts)[i] : yt; + auto val = elem->ToVal(t); + elems.emplace_back(std::make_shared(val)); + } + else + elems.emplace_back(nullptr); + } +} + +bool ValTrace::SameList(const ValTrace& vt) const { return SameElems(vt); } + +bool ValTrace::SameRecord(const ValTrace& vt) const { return SameElems(vt); } + +bool ValTrace::SameTable(const ValTrace& vt) const { + auto& vt_elems = vt.elems; + auto n = elems.size(); + if ( n != vt_elems.size() ) + return false; + + auto& vt_elems2 = vt.elems2; + auto n2 = elems2.size(); + if ( n2 != vt_elems2.size() ) + return false; + + ASSERT(n2 == 0 || n == n2); + + // We accommodate the possibility that keys are out-of-order + // between the two sets of elements. + + // The following is O(N^2), but presumably if tables are somehow + // involved (in fact we can only get here if they're used as + // indices into other tables), then they'll likely be small. + for ( auto i = 0U; i < n; ++i ) { + auto& elem_i = elems[i]; + + // See if we can find a match for it. If we do, we don't + // have to worry that another entry matched it too, since + // all table/set indices will be distinct. + auto j = 0U; + for ( ; j < n; ++j ) { + auto& vt_elem_j = vt_elems[j]; + if ( *elem_i == *vt_elem_j ) + break; + } + + if ( j == n ) + // No match for the index. + return false; + + if ( n2 > 0 ) { + // Need a match for the corresponding yield values. + if ( *elems2[i] != *vt_elems2[j] ) + return false; + } + } + + return true; +} + +bool ValTrace::SameVector(const ValTrace& vt) const { return SameElems(vt); } + +bool ValTrace::SameElems(const ValTrace& vt) const { + auto& vt_elems = vt.elems; + auto n = elems.size(); + if ( n != vt_elems.size() ) + return false; + + for ( auto i = 0U; i < n; ++i ) { + auto& trace_i = elems[i]; + auto& vt_trace_i = vt_elems[i]; + + if ( trace_i && vt_trace_i ) { + if ( *trace_i != *vt_trace_i ) + return false; + } + + else if ( trace_i || vt_trace_i ) + return false; + } + + return true; +} + +bool ValTrace::SameSingleton(const ValTrace& vt) const { return ! IsAggr(t) && *this == vt; } + +void ValTrace::ComputeRecordDelta(const ValTrace* prev, DeltaVector& deltas) const { + auto& prev_elems = prev->elems; + auto n = elems.size(); + if ( n != prev_elems.size() ) + reporter->InternalError("size inconsistency in ValTrace::ComputeRecordDelta"); + + for ( auto i = 0U; i < n; ++i ) { + const auto trace_i = elems[i].get(); + const auto prev_trace_i = prev_elems[i].get(); + + if ( trace_i ) { + if ( prev_trace_i ) { + auto& v = trace_i->GetVal(); + auto& prev_v = prev_trace_i->GetVal(); + + if ( v == prev_v ) { + trace_i->ComputeDelta(prev_trace_i, deltas); + continue; + } + + if ( trace_i->SameSingleton(*prev_trace_i) ) + // No further work needed. + continue; + } + + deltas.emplace_back(std::make_unique(this, i, trace_i->GetVal())); + } + + else if ( prev_trace_i ) + deltas.emplace_back(std::make_unique(this, i)); + } +} + +void ValTrace::ComputeTableDelta(const ValTrace* prev, DeltaVector& deltas) const { + auto& prev_elems = prev->elems; + auto& prev_elems2 = prev->elems2; + + auto n = elems.size(); + auto is_set = elems2.size() == 0; + auto prev_n = prev_elems.size(); + + // We can't compare pointers for the indices because they're + // new objects generated afresh by TableVal::ToMap. So we do + // explicit full comparisons for equality, distinguishing values + // newly added, common to both, or (implicitly) removed. We'll + // then go through the common to check them further. + // + // Our approach is O(N^2), but presumably these tables aren't + // large, and in any case generating event traces is not something + // requiring high performance, so we opt for conceptual simplicity. + + // Track which index values are newly added: + std::set added_indices; + + // Track which entry traces are in common. Indexed by previous + // trace elem index, yielding current trace elem index. + std::map common_entries; + + for ( auto i = 0U; i < n; ++i ) { + const auto trace_i = elems[i].get(); + + bool common = false; + + for ( auto j = 0U; j < prev_n; ++j ) { + const auto prev_trace_j = prev_elems[j].get(); + + if ( *trace_i == *prev_trace_j ) { + common_entries[j] = i; + common = true; + break; + } + } + + if ( ! common ) { + auto v = trace_i->GetVal(); + + if ( is_set ) + deltas.emplace_back(std::make_unique(this, v)); + else { + auto yield = elems2[i]->GetVal(); + deltas.emplace_back(std::make_unique(this, v, yield)); + } + + added_indices.insert(v.get()); + } + } + + for ( auto j = 0U; j < prev_n; ++j ) { + auto common_pair = common_entries.find(j); + if ( common_pair == common_entries.end() ) { + auto& prev_trace = prev_elems[j]; + auto& v = prev_trace->GetVal(); + deltas.emplace_back(std::make_unique(this, v)); + continue; + } + + if ( is_set ) + continue; + + // If we get here, we're analyzing a table for which there's + // a common index. The remaining question is whether the + // yield has changed. + auto i = common_pair->second; + auto& trace2 = elems2[i]; + const auto prev_trace2 = prev_elems2[j]; + + auto& yield = trace2->GetVal(); + auto& prev_yield = prev_trace2->GetVal(); + + if ( yield == prev_yield ) + // Same yield, look for differences in its sub-elements. + trace2->ComputeDelta(prev_trace2.get(), deltas); + + else if ( ! trace2->SameSingleton(*prev_trace2) ) + deltas.emplace_back(std::make_unique(this, elems[i]->GetVal(), yield)); + } +} + +void ValTrace::ComputeVectorDelta(const ValTrace* prev, DeltaVector& deltas) const { + auto& prev_elems = prev->elems; + auto n = elems.size(); + auto prev_n = prev_elems.size(); + + // TODO: The following hasn't been tested for robustness to vector holes. + + if ( n < prev_n ) { + // The vector shrank in size. Easiest to just build it + // from scratch. + deltas.emplace_back(std::make_unique(this)); + return; + } + + // Look for existing entries that need reassignment. + auto i = 0U; + for ( ; i < prev_n; ++i ) { + const auto trace_i = elems[i].get(); + const auto prev_trace_i = prev_elems[i].get(); + + auto& elem_i = trace_i->GetVal(); + auto& prev_elem_i = prev_trace_i->GetVal(); + + if ( elem_i == prev_elem_i ) + trace_i->ComputeDelta(prev_trace_i, deltas); + else if ( ! trace_i->SameSingleton(*prev_trace_i) ) + deltas.emplace_back(std::make_unique(this, i, elem_i)); + } + + // Now append any new entries. + for ( ; i < n; ++i ) { + auto& trace_i = elems[i]; + auto& elem_i = trace_i->GetVal(); + deltas.emplace_back(std::make_unique(this, i, elem_i)); + } +} + +std::string ValDelta::Generate(ValTraceMgr* vtm) const { return ""; } + +std::string DeltaReplaceValue::Generate(ValTraceMgr* vtm) const { return std::string(" = ") + vtm->ValName(new_val); } + +std::string DeltaSetField::Generate(ValTraceMgr* vtm) const { + auto rt = vt->GetType()->AsRecordType(); + auto f = rt->FieldName(field); + return std::string("$") + f + " = " + vtm->ValName(new_val); +} + +std::string DeltaRemoveField::Generate(ValTraceMgr* vtm) const { + auto rt = vt->GetType()->AsRecordType(); + auto f = rt->FieldName(field); + return std::string("delete ") + vtm->ValName(vt) + "$" + f; +} + +std::string DeltaRecordCreate::Generate(ValTraceMgr* vtm) const { + auto rv = cast_intrusive(vt->GetVal()); + auto rt = rv->GetType(); + auto n = rt->NumFields(); + + std::string args; + + for ( auto i = 0; i < n; ++i ) { + auto v_i = rv->GetField(i); + if ( v_i ) { + if ( ! args.empty() ) + args += ", "; + + args += std::string("$") + rt->FieldName(i) + "=" + vtm->ValName(v_i); + } + } + + auto name = rt->GetName(); + if ( name.empty() ) + name = "record"; + + return std::string(" = ") + name + "(" + args + ")"; +} + +std::string DeltaSetSetEntry::Generate(ValTraceMgr* vtm) const { + return std::string("add ") + vtm->ValName(vt) + "[" + vtm->ValName(index) + "]"; +} + +std::string DeltaSetTableEntry::Generate(ValTraceMgr* vtm) const { + return std::string("[") + vtm->ValName(index) + "] = " + vtm->ValName(new_val); +} + +std::string DeltaRemoveTableEntry::Generate(ValTraceMgr* vtm) const { + return std::string("delete ") + vtm->ValName(vt) + "[" + vtm->ValName(index) + "]"; +} + +std::string DeltaSetCreate::Generate(ValTraceMgr* vtm) const { + auto sv = cast_intrusive(vt->GetVal()); + auto members = sv->ToMap(); + + std::string args; + + for ( auto& m : members ) { + if ( ! args.empty() ) + args += ", "; + + args += vtm->ValName(m.first); + } + + auto name = sv->GetType()->GetName(); + if ( name.empty() ) + name = "set"; + + return std::string(" = ") + name + "(" + args + ")"; +} + +std::string DeltaTableCreate::Generate(ValTraceMgr* vtm) const { + auto tv = cast_intrusive(vt->GetVal()); + auto members = tv->ToMap(); + + std::string args; + + for ( auto& m : members ) { + if ( ! args.empty() ) + args += ", "; + + args += std::string("[") + vtm->ValName(m.first) + "] = " + vtm->ValName(m.second); + } + + auto name = tv->GetType()->GetName(); + if ( name.empty() ) + name = "table"; + + return std::string(" = ") + name + "(" + args + ")"; +} + +std::string DeltaVectorSet::Generate(ValTraceMgr* vtm) const { + return std::string("[") + std::to_string(index) + "] = " + vtm->ValName(elem); +} + +std::string DeltaVectorAppend::Generate(ValTraceMgr* vtm) const { + return std::string("[") + std::to_string(index) + "] = " + vtm->ValName(elem); +} + +std::string DeltaVectorCreate::Generate(ValTraceMgr* vtm) const { + auto& elems = vt->GetElems(); + std::string vec; + + for ( auto& e : elems ) { + if ( vec.size() > 0 ) + vec += ", "; + + vec += vtm->ValName(e->GetVal()); + } + + return std::string(" = vector(") + vec + ")"; +} + +std::string DeltaUnsupportedCreate::Generate(ValTraceMgr* vtm) const { + return " = UNSUPPORTED " + obj_desc_short(vt->GetVal()->GetType().get()); +} + +EventTrace::EventTrace(const ScriptFunc* _ev, double _nt, size_t event_num) : ev(_ev), nt(_nt) { + auto ev_name = std::regex_replace(ev->Name(), std::regex(":"), "_"); + + name = ev_name + "_" + std::to_string(event_num) + "__et"; +} void EventTrace::Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, - int num_pre) const - { - int offset = 0; - for ( auto& d : dvec ) - { - auto& val = d.GetVal(); - - if ( d.IsFirstDef() && vtm.IsGlobal(val) ) - { - auto& val_name = vtm.ValName(val); - - std::string type_name; - auto& t = val->GetType(); - auto& tn = t->GetName(); - if ( tn.empty() ) - { - ODesc d; - t->Describe(&d); - type_name = d.Description(); - } - else - type_name = tn; - - auto anno = offset < num_pre ? " # from script" : ""; - - fprintf(f, "global %s: %s;%s\n", val_name.c_str(), type_name.c_str(), anno); - } - - ++offset; - } - - fprintf(f, "event %s()\n", name.c_str()); - fprintf(f, "\t{\n"); - - offset = 0; - for ( auto& d : dvec ) - { - fprintf(f, "\t"); - - auto& val = d.GetVal(); - bool define_local = d.IsFirstDef() && ! vtm.IsGlobal(val); - - if ( define_local ) - fprintf(f, "local "); - - if ( d.NeedsLHS() ) - { - fprintf(f, "%s", vtm.ValName(val).c_str()); - - if ( define_local ) - fprintf(f, ": %s", obj_desc_short(val->GetType().get()).c_str()); - } - - auto anno = offset < num_pre ? " # from script" : ""; - - fprintf(f, "%s;%s\n", d.RHS().c_str(), anno); - - ++offset; - } - - if ( ! dvec.empty() ) - fprintf(f, "\n"); - - fprintf(f, "\tevent %s(%s);\n\n", ev->Name(), args.c_str()); - - if ( successor.empty() ) - { - // The following isn't necessary with our current approach - // to managing chains of events, which avoids having to set - // exit_only_after_terminate=T. - // fprintf(f, "\tterminate();\n"); - } - else - { - auto tm = vtm.TimeConstant(nt); - fprintf(f, "\tset_network_time(%s);\n", tm.c_str()); - fprintf(f, "\tevent __EventTrace::%s();\n", successor.c_str()); - } - - fprintf(f, "\t}\n"); - } - -void EventTrace::Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, - std::string successor) const - { - if ( predecessor ) - { - auto& pre_deltas = predecessor->post_deltas; - int num_pre = pre_deltas.size(); - - if ( num_pre > 0 ) - { - auto total_deltas = pre_deltas; - total_deltas.insert(total_deltas.end(), deltas.begin(), deltas.end()); - Generate(f, vtm, total_deltas, successor, num_pre); - return; - } - } - - Generate(f, vtm, deltas, successor); - } - -void ValTraceMgr::TraceEventValues(std::shared_ptr et, const zeek::Args* args) - { - curr_ev = std::move(et); - - auto num_vals = vals.size(); - - std::string ev_args; - for ( auto& a : *args ) - { - AddVal(a); - - if ( ! ev_args.empty() ) - ev_args += ", "; - - ev_args += ValName(a); - } - - curr_ev->SetArgs(ev_args); - - // Now look for any values newly-processed with this event and - // remember them so we can catch uses of them in future events. - for ( auto i = num_vals; i < vals.size(); ++i ) - { - processed_vals.insert(vals[i].get()); - ASSERT(val_names.count(vals[i].get()) > 0); - } - } - -void ValTraceMgr::FinishCurrentEvent(const zeek::Args* args) - { - auto num_vals = vals.size(); - - curr_ev->SetDoingPost(); - - for ( auto& a : *args ) - AddVal(a); - - for ( auto i = num_vals; i < vals.size(); ++i ) - processed_vals.insert(vals[i].get()); - } - -const std::string& ValTraceMgr::ValName(const ValPtr& v) - { - auto find = val_names.find(v.get()); - if ( find == val_names.end() ) - find = val_names.insert({v.get(), GenValName(v)}).first; - - ValUsed(v); - - return find->second; - } - -std::string ValTraceMgr::TimeConstant(double t) - { - if ( t < std::max(base_time, 1e6) ) - return "double_to_time(" + std::to_string(t) + ")"; - - if ( ! base_time ) - base_time = t; - - if ( t == base_time ) - return "double_to_time(__base_time)"; - - t -= base_time; - return "double_to_time(__base_time + " + std::to_string(t) + ")"; - } - -void ValTraceMgr::AddVal(ValPtr v) - { - auto mapping = val_map.find(v.get()); - - if ( mapping == val_map.end() ) - NewVal(v); - else - { - auto vt = std::make_shared(v); - AssessChange(vt.get(), mapping->second.get()); - val_map[v.get()] = vt; - } - } - -void ValTraceMgr::NewVal(ValPtr v) - { - // Make sure the Val sticks around into the future. - vals.push_back(v); - - auto vt = std::make_shared(v); - AssessChange(vt.get(), nullptr); - val_map[v.get()] = vt; - } - -void ValTraceMgr::ValUsed(const ValPtr& v) - { - ASSERT(val_names.count(v.get()) > 0); - if ( processed_vals.count(v.get()) > 0 ) - // We saw this value when processing a previous event. - globals.insert(v.get()); - } - -void ValTraceMgr::AssessChange(const ValTrace* vt, const ValTrace* prev_vt) - { - DeltaVector deltas; - - vt->ComputeDelta(prev_vt, deltas); - - // Used to track deltas across the batch, to suppress redundant ones - // (which can arise due to two aggregates both including the same - // sub-element). - std::unordered_set previous_deltas; - - for ( auto& d : deltas ) - { - auto vp = d->GetValTrace()->GetVal(); - auto v = vp.get(); - auto rhs = d->Generate(this); - - bool needs_lhs = d->NeedsLHS(); - bool is_first_def = false; - - if ( needs_lhs && val_names.count(v) == 0 ) - { - TrackVar(v); - is_first_def = true; - } - - ASSERT(val_names.count(v) > 0); - - // The "/" in the following is just to have a delimiter - // to make sure the string is unambiguous. - auto full_delta = val_names[v] + "/" + rhs; - if ( previous_deltas.count(full_delta) > 0 ) - continue; - - previous_deltas.insert(full_delta); - - ValUsed(vp); - curr_ev->AddDelta(vp, rhs, needs_lhs, is_first_def); - } - - auto& v = vt->GetVal(); - if ( IsAggr(v->GetType()) && (prev_vt || ! IsUnspecifiedAggregate(v)) ) - ValUsed(vt->GetVal()); - } - -void ValTraceMgr::TrackVar(const Val* v) - { - std::string base_name = IsUnsupported(v) ? "UNSUPPORTED" : "val"; - auto val_name = "__" + base_name + std::to_string(num_vars++); - val_names[v] = val_name; - } - -std::string ValTraceMgr::GenValName(const ValPtr& v) - { - if ( IsAggr(v->GetType()) && ! IsUnspecifiedAggregate(v) ) - { // Aggregate shouldn't exist; create it - ASSERT(val_map.count(v.get()) == 0); - NewVal(v); - return val_names[v.get()]; - } - - // Non-aggregate (or unspecified aggregate) can be expressed using - // a constant. - auto t = v->GetType(); - auto tag = t->Tag(); - std::string rep; - bool track_constant = false; - - switch ( tag ) - { - case TYPE_STRING: - { - auto s = v->AsStringVal(); - rep = escape_string(s->Bytes(), s->Len()); - track_constant = s->Len() > 0; - break; - } - - case TYPE_LIST: - { - auto lv = cast_intrusive(v); - for ( auto& v_i : lv->Vals() ) - { - if ( ! rep.empty() ) - rep += ", "; - - rep += ValName(v_i); - } - break; - } - - case TYPE_FUNC: - rep = v->AsFunc()->Name(); - break; - - case TYPE_TIME: - { - auto tm = v->AsDouble(); - rep = TimeConstant(tm); - - if ( tm > 0.0 && rep.find("__base_time") == std::string::npos ) - // We're not representing it using base_time. - track_constant = true; - - break; - } - - case TYPE_INTERVAL: - rep = "double_to_interval(" + std::to_string(v->AsDouble()) + ")"; - break; - - case TYPE_TABLE: - rep = t->Yield() ? "table()" : "set()"; - break; - - case TYPE_VECTOR: - rep = "vector()"; - break; - - case TYPE_PATTERN: - case TYPE_PORT: - case TYPE_ADDR: - case TYPE_SUBNET: - { - ODesc d; - v->Describe(&d); - rep = d.Description(); - track_constant = true; - - if ( tag == TYPE_ADDR || tag == TYPE_SUBNET ) - { - // Fix up deficiency that IPv6 addresses are - // described without surrounding []'s. - const auto& addr = tag == TYPE_ADDR ? v->AsAddr() : v->AsSubNet().Prefix(); - if ( addr.GetFamily() == IPv6 ) - rep = "[" + rep + "]"; - } - } - break; - - default: - { - ODesc d; - v->Describe(&d); - rep = d.Description(); - } - } - - val_names[v.get()] = rep; - vals.push_back(v); - - if ( track_constant ) - constants[tag].insert(rep); - - std::array constants; - - return rep; - } - -bool ValTraceMgr::IsUnspecifiedAggregate(const ValPtr& v) const - { - auto t = v->GetType()->Tag(); - - if ( t == TYPE_TABLE && v->GetType()->IsUnspecifiedTable() ) - return true; - - if ( t == TYPE_VECTOR && v->GetType()->IsUnspecifiedVector() ) - return true; - - return false; - } - -bool ValTraceMgr::IsUnsupported(const Val* v) const - { - auto t = v->GetType()->Tag(); - return t == TYPE_ANY || t == TYPE_FILE || t == TYPE_OPAQUE; - } - -EventTraceMgr::EventTraceMgr(const std::string& trace_file) - { - f = fopen(trace_file.c_str(), "w"); - if ( ! f ) - reporter->FatalError("can't open event trace file %s", trace_file.c_str()); - } - -EventTraceMgr::~EventTraceMgr() - { - if ( events.empty() ) - return; - - fprintf(f, "module __EventTrace;\n\n"); - - auto bt = vtm.GetBaseTime(); - - if ( bt ) - fprintf(f, "global __base_time = %.06f;\n\n", bt); - - for ( auto& e : events ) - fprintf(f, "global %s: event();\n", e->GetName()); - - fprintf(f, "\nevent zeek_init() &priority=-999999\n"); - fprintf(f, "\t{\n"); - fprintf(f, "\tevent __EventTrace::%s();\n", events.front()->GetName()); - fprintf(f, "\t}\n"); - - for ( auto i = 0U; i < events.size(); ++i ) - { - fprintf(f, "\n"); - - auto predecessor = i > 0 ? events[i - 1] : nullptr; - auto successor = i + 1 < events.size() ? events[i + 1]->GetName() : ""; - events[i]->Generate(f, vtm, predecessor.get(), successor); - } - - const auto& constants = vtm.GetConstants(); - - for ( auto tag = 0; tag < NUM_TYPES; ++tag ) - { - auto& c_t = constants[tag]; - if ( c_t.empty() && (tag != TYPE_TIME || ! bt) ) - continue; - - fprintf(f, "\n# constants of type %s:\n", type_name(TypeTag(tag))); - if ( tag == TYPE_TIME && bt ) - fprintf(f, "#\t__base_time = %.06f\n", bt); - - for ( auto& c : c_t ) - fprintf(f, "#\t%s\n", c.c_str()); - } - - fclose(f); - } + int num_pre) const { + int offset = 0; + for ( auto& d : dvec ) { + auto& val = d.GetVal(); -void EventTraceMgr::StartEvent(const ScriptFunc* ev, const zeek::Args* args) - { - if ( script_events.count(ev->Name()) > 0 ) - return; + if ( d.IsFirstDef() && vtm.IsGlobal(val) ) { + auto& val_name = vtm.ValName(val); - auto nt = run_state::network_time; - if ( nt == 0.0 || util::streq(ev->Name(), "zeek_init") ) - return; + std::string type_name; + auto& t = val->GetType(); + auto& tn = t->GetName(); + if ( tn.empty() ) { + ODesc d; + t->Describe(&d); + type_name = d.Description(); + } + else + type_name = tn; - if ( ! vtm.GetBaseTime() ) - vtm.SetBaseTime(nt); + auto anno = offset < num_pre ? " # from script" : ""; - auto et = std::make_shared(ev, nt, events.size()); - events.emplace_back(et); + fprintf(f, "global %s: %s;%s\n", val_name.c_str(), type_name.c_str(), anno); + } - vtm.TraceEventValues(et, args); - } + ++offset; + } -void EventTraceMgr::EndEvent(const ScriptFunc* ev, const zeek::Args* args) - { - if ( script_events.count(ev->Name()) > 0 ) - return; + fprintf(f, "event %s()\n", name.c_str()); + fprintf(f, "\t{\n"); - if ( run_state::network_time > 0.0 && ! util::streq(ev->Name(), "zeek_init") ) - vtm.FinishCurrentEvent(args); - } + offset = 0; + for ( auto& d : dvec ) { + fprintf(f, "\t"); -void EventTraceMgr::ScriptEventQueued(const EventHandlerPtr& h) - { - script_events.insert(h->Name()); - } + auto& val = d.GetVal(); + bool define_local = d.IsFirstDef() && ! vtm.IsGlobal(val); - } // namespace zeek::detail + if ( define_local ) + fprintf(f, "local "); + + if ( d.NeedsLHS() ) { + fprintf(f, "%s", vtm.ValName(val).c_str()); + + if ( define_local ) + fprintf(f, ": %s", obj_desc_short(val->GetType().get()).c_str()); + } + + auto anno = offset < num_pre ? " # from script" : ""; + + fprintf(f, "%s;%s\n", d.RHS().c_str(), anno); + + ++offset; + } + + if ( ! dvec.empty() ) + fprintf(f, "\n"); + + fprintf(f, "\tevent %s(%s);\n\n", ev->Name(), args.c_str()); + + if ( successor.empty() ) { + // The following isn't necessary with our current approach + // to managing chains of events, which avoids having to set + // exit_only_after_terminate=T. + // fprintf(f, "\tterminate();\n"); + } + else { + auto tm = vtm.TimeConstant(nt); + fprintf(f, "\tset_network_time(%s);\n", tm.c_str()); + fprintf(f, "\tevent __EventTrace::%s();\n", successor.c_str()); + } + + fprintf(f, "\t}\n"); +} + +void EventTrace::Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, std::string successor) const { + if ( predecessor ) { + auto& pre_deltas = predecessor->post_deltas; + int num_pre = pre_deltas.size(); + + if ( num_pre > 0 ) { + auto total_deltas = pre_deltas; + total_deltas.insert(total_deltas.end(), deltas.begin(), deltas.end()); + Generate(f, vtm, total_deltas, successor, num_pre); + return; + } + } + + Generate(f, vtm, deltas, successor); +} + +void ValTraceMgr::TraceEventValues(std::shared_ptr et, const zeek::Args* args) { + curr_ev = std::move(et); + + auto num_vals = vals.size(); + + std::string ev_args; + for ( auto& a : *args ) { + AddVal(a); + + if ( ! ev_args.empty() ) + ev_args += ", "; + + ev_args += ValName(a); + } + + curr_ev->SetArgs(ev_args); + + // Now look for any values newly-processed with this event and + // remember them so we can catch uses of them in future events. + for ( auto i = num_vals; i < vals.size(); ++i ) { + processed_vals.insert(vals[i].get()); + ASSERT(val_names.count(vals[i].get()) > 0); + } +} + +void ValTraceMgr::FinishCurrentEvent(const zeek::Args* args) { + auto num_vals = vals.size(); + + curr_ev->SetDoingPost(); + + for ( auto& a : *args ) + AddVal(a); + + for ( auto i = num_vals; i < vals.size(); ++i ) + processed_vals.insert(vals[i].get()); +} + +const std::string& ValTraceMgr::ValName(const ValPtr& v) { + auto find = val_names.find(v.get()); + if ( find == val_names.end() ) + find = val_names.insert({v.get(), GenValName(v)}).first; + + ValUsed(v); + + return find->second; +} + +std::string ValTraceMgr::TimeConstant(double t) { + if ( t < std::max(base_time, 1e6) ) + return "double_to_time(" + std::to_string(t) + ")"; + + if ( ! base_time ) + base_time = t; + + if ( t == base_time ) + return "double_to_time(__base_time)"; + + t -= base_time; + return "double_to_time(__base_time + " + std::to_string(t) + ")"; +} + +void ValTraceMgr::AddVal(ValPtr v) { + auto mapping = val_map.find(v.get()); + + if ( mapping == val_map.end() ) + NewVal(v); + else { + auto vt = std::make_shared(v); + AssessChange(vt.get(), mapping->second.get()); + val_map[v.get()] = vt; + } +} + +void ValTraceMgr::NewVal(ValPtr v) { + // Make sure the Val sticks around into the future. + vals.push_back(v); + + auto vt = std::make_shared(v); + AssessChange(vt.get(), nullptr); + val_map[v.get()] = vt; +} + +void ValTraceMgr::ValUsed(const ValPtr& v) { + ASSERT(val_names.count(v.get()) > 0); + if ( processed_vals.count(v.get()) > 0 ) + // We saw this value when processing a previous event. + globals.insert(v.get()); +} + +void ValTraceMgr::AssessChange(const ValTrace* vt, const ValTrace* prev_vt) { + DeltaVector deltas; + + vt->ComputeDelta(prev_vt, deltas); + + // Used to track deltas across the batch, to suppress redundant ones + // (which can arise due to two aggregates both including the same + // sub-element). + std::unordered_set previous_deltas; + + for ( auto& d : deltas ) { + auto vp = d->GetValTrace()->GetVal(); + auto v = vp.get(); + auto rhs = d->Generate(this); + + bool needs_lhs = d->NeedsLHS(); + bool is_first_def = false; + + if ( needs_lhs && val_names.count(v) == 0 ) { + TrackVar(v); + is_first_def = true; + } + + ASSERT(val_names.count(v) > 0); + + // The "/" in the following is just to have a delimiter + // to make sure the string is unambiguous. + auto full_delta = val_names[v] + "/" + rhs; + if ( previous_deltas.count(full_delta) > 0 ) + continue; + + previous_deltas.insert(full_delta); + + ValUsed(vp); + curr_ev->AddDelta(vp, rhs, needs_lhs, is_first_def); + } + + auto& v = vt->GetVal(); + if ( IsAggr(v->GetType()) && (prev_vt || ! IsUnspecifiedAggregate(v)) ) + ValUsed(vt->GetVal()); +} + +void ValTraceMgr::TrackVar(const Val* v) { + std::string base_name = IsUnsupported(v) ? "UNSUPPORTED" : "val"; + auto val_name = "__" + base_name + std::to_string(num_vars++); + val_names[v] = val_name; +} + +std::string ValTraceMgr::GenValName(const ValPtr& v) { + if ( IsAggr(v->GetType()) && ! IsUnspecifiedAggregate(v) ) { // Aggregate shouldn't exist; create it + ASSERT(val_map.count(v.get()) == 0); + NewVal(v); + return val_names[v.get()]; + } + + // Non-aggregate (or unspecified aggregate) can be expressed using + // a constant. + auto t = v->GetType(); + auto tag = t->Tag(); + std::string rep; + bool track_constant = false; + + switch ( tag ) { + case TYPE_STRING: { + auto s = v->AsStringVal(); + rep = escape_string(s->Bytes(), s->Len()); + track_constant = s->Len() > 0; + break; + } + + case TYPE_LIST: { + auto lv = cast_intrusive(v); + for ( auto& v_i : lv->Vals() ) { + if ( ! rep.empty() ) + rep += ", "; + + rep += ValName(v_i); + } + break; + } + + case TYPE_FUNC: rep = v->AsFunc()->Name(); break; + + case TYPE_TIME: { + auto tm = v->AsDouble(); + rep = TimeConstant(tm); + + if ( tm > 0.0 && rep.find("__base_time") == std::string::npos ) + // We're not representing it using base_time. + track_constant = true; + + break; + } + + case TYPE_INTERVAL: rep = "double_to_interval(" + std::to_string(v->AsDouble()) + ")"; break; + + case TYPE_TABLE: rep = t->Yield() ? "table()" : "set()"; break; + + case TYPE_VECTOR: rep = "vector()"; break; + + case TYPE_PATTERN: + case TYPE_PORT: + case TYPE_ADDR: + case TYPE_SUBNET: { + ODesc d; + v->Describe(&d); + rep = d.Description(); + track_constant = true; + + if ( tag == TYPE_ADDR || tag == TYPE_SUBNET ) { + // Fix up deficiency that IPv6 addresses are + // described without surrounding []'s. + const auto& addr = tag == TYPE_ADDR ? v->AsAddr() : v->AsSubNet().Prefix(); + if ( addr.GetFamily() == IPv6 ) + rep = "[" + rep + "]"; + } + } break; + + default: { + ODesc d; + v->Describe(&d); + rep = d.Description(); + } + } + + val_names[v.get()] = rep; + vals.push_back(v); + + if ( track_constant ) + constants[tag].insert(rep); + + std::array constants; + + return rep; +} + +bool ValTraceMgr::IsUnspecifiedAggregate(const ValPtr& v) const { + auto t = v->GetType()->Tag(); + + if ( t == TYPE_TABLE && v->GetType()->IsUnspecifiedTable() ) + return true; + + if ( t == TYPE_VECTOR && v->GetType()->IsUnspecifiedVector() ) + return true; + + return false; +} + +bool ValTraceMgr::IsUnsupported(const Val* v) const { + auto t = v->GetType()->Tag(); + return t == TYPE_ANY || t == TYPE_FILE || t == TYPE_OPAQUE; +} + +EventTraceMgr::EventTraceMgr(const std::string& trace_file) { + f = fopen(trace_file.c_str(), "w"); + if ( ! f ) + reporter->FatalError("can't open event trace file %s", trace_file.c_str()); +} + +EventTraceMgr::~EventTraceMgr() { + if ( events.empty() ) + return; + + fprintf(f, "module __EventTrace;\n\n"); + + auto bt = vtm.GetBaseTime(); + + if ( bt ) + fprintf(f, "global __base_time = %.06f;\n\n", bt); + + for ( auto& e : events ) + fprintf(f, "global %s: event();\n", e->GetName()); + + fprintf(f, "\nevent zeek_init() &priority=-999999\n"); + fprintf(f, "\t{\n"); + fprintf(f, "\tevent __EventTrace::%s();\n", events.front()->GetName()); + fprintf(f, "\t}\n"); + + for ( auto i = 0U; i < events.size(); ++i ) { + fprintf(f, "\n"); + + auto predecessor = i > 0 ? events[i - 1] : nullptr; + auto successor = i + 1 < events.size() ? events[i + 1]->GetName() : ""; + events[i]->Generate(f, vtm, predecessor.get(), successor); + } + + const auto& constants = vtm.GetConstants(); + + for ( auto tag = 0; tag < NUM_TYPES; ++tag ) { + auto& c_t = constants[tag]; + if ( c_t.empty() && (tag != TYPE_TIME || ! bt) ) + continue; + + fprintf(f, "\n# constants of type %s:\n", type_name(TypeTag(tag))); + if ( tag == TYPE_TIME && bt ) + fprintf(f, "#\t__base_time = %.06f\n", bt); + + for ( auto& c : c_t ) + fprintf(f, "#\t%s\n", c.c_str()); + } + + fclose(f); +} + +void EventTraceMgr::StartEvent(const ScriptFunc* ev, const zeek::Args* args) { + if ( script_events.count(ev->Name()) > 0 ) + return; + + auto nt = run_state::network_time; + if ( nt == 0.0 || util::streq(ev->Name(), "zeek_init") ) + return; + + if ( ! vtm.GetBaseTime() ) + vtm.SetBaseTime(nt); + + auto et = std::make_shared(ev, nt, events.size()); + events.emplace_back(et); + + vtm.TraceEventValues(et, args); +} + +void EventTraceMgr::EndEvent(const ScriptFunc* ev, const zeek::Args* args) { + if ( script_events.count(ev->Name()) > 0 ) + return; + + if ( run_state::network_time > 0.0 && ! util::streq(ev->Name(), "zeek_init") ) + vtm.FinishCurrentEvent(args); +} + +void EventTraceMgr::ScriptEventQueued(const EventHandlerPtr& h) { script_events.insert(h->Name()); } + +} // namespace zeek::detail diff --git a/src/EventTrace.h b/src/EventTrace.h index 092f2badd6..de111de8f6 100644 --- a/src/EventTrace.h +++ b/src/EventTrace.h @@ -5,37 +5,35 @@ #include "zeek/Val.h" #include "zeek/ZeekArgs.h" -namespace zeek::detail - { +namespace zeek::detail { class ValTrace; class ValTraceMgr; // Abstract class for capturing a single difference between two script-level // values. Includes notions of inserting, changing, or deleting a value. -class ValDelta - { +class ValDelta { public: - ValDelta(const ValTrace* _vt) : vt(_vt) { } - virtual ~ValDelta() { } + ValDelta(const ValTrace* _vt) : vt(_vt) {} + virtual ~ValDelta() {} - // Return a string that performs the update operation, expressed - // as Zeek scripting. Does not include a terminating semicolon. - virtual std::string Generate(ValTraceMgr* vtm) const; + // Return a string that performs the update operation, expressed + // as Zeek scripting. Does not include a terminating semicolon. + virtual std::string Generate(ValTraceMgr* vtm) const; - // Whether the generated string needs the affected value to - // explicitly appear on the left-hand-side. Note that this - // might not be as a simple "LHS = RHS" assignment, but instead - // as "LHS$field = RHS" or "LHS[index] = RHS". - // - // Returns false for generated strings like "delete LHS[index]". - virtual bool NeedsLHS() const { return true; } + // Whether the generated string needs the affected value to + // explicitly appear on the left-hand-side. Note that this + // might not be as a simple "LHS = RHS" assignment, but instead + // as "LHS$field = RHS" or "LHS[index] = RHS". + // + // Returns false for generated strings like "delete LHS[index]". + virtual bool NeedsLHS() const { return true; } - const ValTrace* GetValTrace() const { return vt; } + const ValTrace* GetValTrace() const { return vt; } protected: - const ValTrace* vt; - }; + const ValTrace* vt; +}; using DeltaVector = std::vector>; @@ -43,464 +41,426 @@ using DeltaVector = std::vector>; // For non-aggregates, this is simply the Val object, but for aggregates // it is (recursively) each of the sub-elements, in a manner that can then // be readily compared against future instances. -class ValTrace - { +class ValTrace { public: - ValTrace(const ValPtr& v); - ~ValTrace() = default; + ValTrace(const ValPtr& v); + ~ValTrace() = default; - const ValPtr& GetVal() const { return v; } - const TypePtr& GetType() const { return t; } - const auto& GetElems() const { return elems; } + const ValPtr& GetVal() const { return v; } + const TypePtr& GetType() const { return t; } + const auto& GetElems() const { return elems; } - // Returns true if this trace and the given one represent the - // same underlying value. Can involve subelement-by-subelement - // (recursive) comparisons. - bool operator==(const ValTrace& vt) const; - bool operator!=(const ValTrace& vt) const { return ! ((*this) == vt); } + // Returns true if this trace and the given one represent the + // same underlying value. Can involve subelement-by-subelement + // (recursive) comparisons. + bool operator==(const ValTrace& vt) const; + bool operator!=(const ValTrace& vt) const { return ! ((*this) == vt); } - // Computes the deltas between a previous ValTrace and this one. - // If "prev" is nil then we're creating this value from scratch - // (though if it's an aggregate, we may reuse existing values - // for some of its components). - // - // Returns the accumulated differences in "deltas". If on return - // nothing was added to "deltas" then the two ValTrace's are equivalent - // (no changes between them). - void ComputeDelta(const ValTrace* prev, DeltaVector& deltas) const; + // Computes the deltas between a previous ValTrace and this one. + // If "prev" is nil then we're creating this value from scratch + // (though if it's an aggregate, we may reuse existing values + // for some of its components). + // + // Returns the accumulated differences in "deltas". If on return + // nothing was added to "deltas" then the two ValTrace's are equivalent + // (no changes between them). + void ComputeDelta(const ValTrace* prev, DeltaVector& deltas) const; private: - // Methods for tracing different types of aggregate values. - void TraceList(const ListValPtr& lv); - void TraceRecord(const RecordValPtr& rv); - void TraceTable(const TableValPtr& tv); - void TraceVector(const VectorValPtr& vv); + // Methods for tracing different types of aggregate values. + void TraceList(const ListValPtr& lv); + void TraceRecord(const RecordValPtr& rv); + void TraceTable(const TableValPtr& tv); + void TraceVector(const VectorValPtr& vv); - // Predicates for comparing different types of aggregates for equality. - bool SameList(const ValTrace& vt) const; - bool SameRecord(const ValTrace& vt) const; - bool SameTable(const ValTrace& vt) const; - bool SameVector(const ValTrace& vt) const; + // Predicates for comparing different types of aggregates for equality. + bool SameList(const ValTrace& vt) const; + bool SameRecord(const ValTrace& vt) const; + bool SameTable(const ValTrace& vt) const; + bool SameVector(const ValTrace& vt) const; - // Helper function that knows about the internal vector-of-subelements - // we use for aggregates. - bool SameElems(const ValTrace& vt) const; + // Helper function that knows about the internal vector-of-subelements + // we use for aggregates. + bool SameElems(const ValTrace& vt) const; - // True if this value is a singleton and it's the same value as - // represented in "vt". - bool SameSingleton(const ValTrace& vt) const; + // True if this value is a singleton and it's the same value as + // represented in "vt". + bool SameSingleton(const ValTrace& vt) const; - // Add to "deltas" the differences needed to turn a previous instance - // of the given type of aggregate to the current instance. - void ComputeRecordDelta(const ValTrace* prev, DeltaVector& deltas) const; - void ComputeTableDelta(const ValTrace* prev, DeltaVector& deltas) const; - void ComputeVectorDelta(const ValTrace* prev, DeltaVector& deltas) const; + // Add to "deltas" the differences needed to turn a previous instance + // of the given type of aggregate to the current instance. + void ComputeRecordDelta(const ValTrace* prev, DeltaVector& deltas) const; + void ComputeTableDelta(const ValTrace* prev, DeltaVector& deltas) const; + void ComputeVectorDelta(const ValTrace* prev, DeltaVector& deltas) const; - // Holds sub-elements for aggregates. - std::vector> elems; + // Holds sub-elements for aggregates. + std::vector> elems; - // A parallel vector used for the yield values of tables. - std::vector> elems2; + // A parallel vector used for the yield values of tables. + std::vector> elems2; - ValPtr v; - TypePtr t; // v's type, for convenience - }; + ValPtr v; + TypePtr t; // v's type, for convenience +}; // Captures the basic notion of a new, non-equivalent value being assigned. -class DeltaReplaceValue : public ValDelta - { +class DeltaReplaceValue : public ValDelta { public: - DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) - : ValDelta(_vt), new_val(std::move(_new_val)) - { - } + DeltaReplaceValue(const ValTrace* _vt, ValPtr _new_val) : ValDelta(_vt), new_val(std::move(_new_val)) {} - std::string Generate(ValTraceMgr* vtm) const override; + std::string Generate(ValTraceMgr* vtm) const override; private: - ValPtr new_val; - }; + ValPtr new_val; +}; // Captures the notion of setting a record field. -class DeltaSetField : public ValDelta - { +class DeltaSetField : public ValDelta { public: - DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val) - : ValDelta(_vt), field(_field), new_val(std::move(_new_val)) - { - } + DeltaSetField(const ValTrace* _vt, int _field, ValPtr _new_val) + : ValDelta(_vt), field(_field), new_val(std::move(_new_val)) {} - std::string Generate(ValTraceMgr* vtm) const override; + std::string Generate(ValTraceMgr* vtm) const override; private: - int field; - ValPtr new_val; - }; + int field; + ValPtr new_val; +}; // Captures the notion of deleting a record field. -class DeltaRemoveField : public ValDelta - { +class DeltaRemoveField : public ValDelta { public: - DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) { } + DeltaRemoveField(const ValTrace* _vt, int _field) : ValDelta(_vt), field(_field) {} - std::string Generate(ValTraceMgr* vtm) const override; - bool NeedsLHS() const override { return false; } + std::string Generate(ValTraceMgr* vtm) const override; + bool NeedsLHS() const override { return false; } private: - int field; - }; + int field; +}; // Captures the notion of creating a record from scratch. -class DeltaRecordCreate : public ValDelta - { +class DeltaRecordCreate : public ValDelta { public: - DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) { } + DeltaRecordCreate(const ValTrace* _vt) : ValDelta(_vt) {} - std::string Generate(ValTraceMgr* vtm) const override; - }; + std::string Generate(ValTraceMgr* vtm) const override; +}; // Captures the notion of adding an element to a set. Use DeltaRemoveTableEntry to // delete values. -class DeltaSetSetEntry : public ValDelta - { +class DeltaSetSetEntry : public ValDelta { public: - DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) { } + DeltaSetSetEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(_index) {} - std::string Generate(ValTraceMgr* vtm) const override; - bool NeedsLHS() const override { return false; } + std::string Generate(ValTraceMgr* vtm) const override; + bool NeedsLHS() const override { return false; } private: - ValPtr index; - }; + ValPtr index; +}; // Captures the notion of setting a table entry (which includes both changing // an existing one and adding a new one). Use DeltaRemoveTableEntry to // delete values. -class DeltaSetTableEntry : public ValDelta - { +class DeltaSetTableEntry : public ValDelta { public: - DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val) - : ValDelta(_vt), index(_index), new_val(std::move(_new_val)) - { - } + DeltaSetTableEntry(const ValTrace* _vt, ValPtr _index, ValPtr _new_val) + : ValDelta(_vt), index(_index), new_val(std::move(_new_val)) {} - std::string Generate(ValTraceMgr* vtm) const override; + std::string Generate(ValTraceMgr* vtm) const override; private: - ValPtr index; - ValPtr new_val; - }; + ValPtr index; + ValPtr new_val; +}; // Captures the notion of removing a table/set entry. -class DeltaRemoveTableEntry : public ValDelta - { +class DeltaRemoveTableEntry : public ValDelta { public: - DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) - : ValDelta(_vt), index(std::move(_index)) - { - } + DeltaRemoveTableEntry(const ValTrace* _vt, ValPtr _index) : ValDelta(_vt), index(std::move(_index)) {} - std::string Generate(ValTraceMgr* vtm) const override; - bool NeedsLHS() const override { return false; } + std::string Generate(ValTraceMgr* vtm) const override; + bool NeedsLHS() const override { return false; } private: - ValPtr index; - }; + ValPtr index; +}; // Captures the notion of creating a set from scratch. -class DeltaSetCreate : public ValDelta - { +class DeltaSetCreate : public ValDelta { public: - DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) { } + DeltaSetCreate(const ValTrace* _vt) : ValDelta(_vt) {} - std::string Generate(ValTraceMgr* vtm) const override; - }; + std::string Generate(ValTraceMgr* vtm) const override; +}; // Captures the notion of creating a table from scratch. -class DeltaTableCreate : public ValDelta - { +class DeltaTableCreate : public ValDelta { public: - DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) { } + DeltaTableCreate(const ValTrace* _vt) : ValDelta(_vt) {} - std::string Generate(ValTraceMgr* vtm) const override; - }; + std::string Generate(ValTraceMgr* vtm) const override; +}; // Captures the notion of changing an element of a vector. -class DeltaVectorSet : public ValDelta - { +class DeltaVectorSet : public ValDelta { public: - DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem) - : ValDelta(_vt), index(_index), elem(std::move(_elem)) - { - } + DeltaVectorSet(const ValTrace* _vt, int _index, ValPtr _elem) + : ValDelta(_vt), index(_index), elem(std::move(_elem)) {} - std::string Generate(ValTraceMgr* vtm) const override; + std::string Generate(ValTraceMgr* vtm) const override; private: - int index; - ValPtr elem; - }; + int index; + ValPtr elem; +}; // Captures the notion of adding an entry to the end of a vector. -class DeltaVectorAppend : public ValDelta - { +class DeltaVectorAppend : public ValDelta { public: - DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem) - : ValDelta(_vt), index(_index), elem(std::move(_elem)) - { - } + DeltaVectorAppend(const ValTrace* _vt, int _index, ValPtr _elem) + : ValDelta(_vt), index(_index), elem(std::move(_elem)) {} - std::string Generate(ValTraceMgr* vtm) const override; + std::string Generate(ValTraceMgr* vtm) const override; private: - int index; - ValPtr elem; - }; + int index; + ValPtr elem; +}; // Captures the notion of replacing a vector wholesale. -class DeltaVectorCreate : public ValDelta - { +class DeltaVectorCreate : public ValDelta { public: - DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) { } + DeltaVectorCreate(const ValTrace* _vt) : ValDelta(_vt) {} - std::string Generate(ValTraceMgr* vtm) const override; - }; + std::string Generate(ValTraceMgr* vtm) const override; +}; // Captures the notion of creating a value with an unsupported type // (like "opaque"). -class DeltaUnsupportedCreate : public ValDelta - { +class DeltaUnsupportedCreate : public ValDelta { public: - DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) { } + DeltaUnsupportedCreate(const ValTrace* _vt) : ValDelta(_vt) {} - std::string Generate(ValTraceMgr* vtm) const override; - }; + std::string Generate(ValTraceMgr* vtm) const override; +}; // Manages the changes to (or creation of) a variable used to represent // a value. -class DeltaGen - { +class DeltaGen { public: - DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def) - : val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), - is_first_def(_is_first_def) - { - } + DeltaGen(ValPtr _val, std::string _rhs, bool _needs_lhs, bool _is_first_def) + : val(std::move(_val)), rhs(std::move(_rhs)), needs_lhs(_needs_lhs), is_first_def(_is_first_def) {} - const ValPtr& GetVal() const { return val; } - const std::string& RHS() const { return rhs; } - bool NeedsLHS() const { return needs_lhs; } - bool IsFirstDef() const { return is_first_def; } + const ValPtr& GetVal() const { return val; } + const std::string& RHS() const { return rhs; } + bool NeedsLHS() const { return needs_lhs; } + bool IsFirstDef() const { return is_first_def; } private: - ValPtr val; + ValPtr val; - // The expression to set the variable to. - std::string rhs; + // The expression to set the variable to. + std::string rhs; - // Whether that expression needs the variable explicitly provides - // on the lefthand side. - bool needs_lhs; + // Whether that expression needs the variable explicitly provides + // on the lefthand side. + bool needs_lhs; - // Whether this is the first definition of the variable (in which - // case we also need to declare the variable). - bool is_first_def; - }; + // Whether this is the first definition of the variable (in which + // case we also need to declare the variable). + bool is_first_def; +}; using DeltaGenVec = std::vector; // Tracks a single event. -class EventTrace - { +class EventTrace { public: - // Constructed in terms of the associated script function, "network - // time" when the event occurred, and the position of this event - // within all of those being traced. - EventTrace(const ScriptFunc* _ev, double _nt, size_t event_num); + // Constructed in terms of the associated script function, "network + // time" when the event occurred, and the position of this event + // within all of those being traced. + EventTrace(const ScriptFunc* _ev, double _nt, size_t event_num); - // Sets a string representation of the arguments (values) being - // passed to the event. - void SetArgs(std::string _args) { args = std::move(_args); } + // Sets a string representation of the arguments (values) being + // passed to the event. + void SetArgs(std::string _args) { args = std::move(_args); } - // Adds to the trace an update for the given value. - void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) - { - auto& d = is_post ? post_deltas : deltas; - d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def)); - } + // Adds to the trace an update for the given value. + void AddDelta(ValPtr val, std::string rhs, bool needs_lhs, bool is_first_def) { + auto& d = is_post ? post_deltas : deltas; + d.emplace_back(DeltaGen(val, rhs, needs_lhs, is_first_def)); + } - // Initially we analyze events pre-execution. When this flag - // is set, we switch to instead analyzing post-execution. The - // difference allows us to annotate the output with "# from script" - // comments that flag changes created by script execution rather - // than event engine activity. - void SetDoingPost() { is_post = true; } + // Initially we analyze events pre-execution. When this flag + // is set, we switch to instead analyzing post-execution. The + // difference allows us to annotate the output with "# from script" + // comments that flag changes created by script execution rather + // than event engine activity. + void SetDoingPost() { is_post = true; } - const char* GetName() const { return name.c_str(); } + const char* GetName() const { return name.c_str(); } - // Generates an internal event handler that sets up the values - // associated with the traced event, followed by queueing the traced - // event, and then queueing the successor internal event handler, - // if any. - // - // "predecessor", if non-nil, gives the event that came just before - // this one (used for "# from script" annotations"). "successor", - // if not empty, gives the name of the successor internal event. - void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, - std::string successor) const; + // Generates an internal event handler that sets up the values + // associated with the traced event, followed by queueing the traced + // event, and then queueing the successor internal event handler, + // if any. + // + // "predecessor", if non-nil, gives the event that came just before + // this one (used for "# from script" annotations"). "successor", + // if not empty, gives the name of the successor internal event. + void Generate(FILE* f, ValTraceMgr& vtm, const EventTrace* predecessor, std::string successor) const; private: - // "dvec" is either just our deltas, or the "post_deltas" of our - // predecessor plus our deltas. - void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, - int num_pre = 0) const; + // "dvec" is either just our deltas, or the "post_deltas" of our + // predecessor plus our deltas. + void Generate(FILE* f, ValTraceMgr& vtm, const DeltaGenVec& dvec, std::string successor, int num_pre = 0) const; - const ScriptFunc* ev; - double nt; - bool is_post = false; + const ScriptFunc* ev; + double nt; + bool is_post = false; - // The deltas needed to construct the values associated with this - // event prior to its execution. - DeltaGenVec deltas; + // The deltas needed to construct the values associated with this + // event prior to its execution. + DeltaGenVec deltas; - // The deltas capturing any changes to the original values as induced - // by executing its event handlers. - DeltaGenVec post_deltas; + // The deltas capturing any changes to the original values as induced + // by executing its event handlers. + DeltaGenVec post_deltas; - // The event's name and a string representation of its arguments. - std::string name; - std::string args; - }; + // The event's name and a string representation of its arguments. + std::string name; + std::string args; +}; // Manages all of the events and associated values seen during the execution. -class ValTraceMgr - { +class ValTraceMgr { public: - // Invoked to trace a new event with the associated arguments. - void TraceEventValues(std::shared_ptr et, const zeek::Args* args); + // Invoked to trace a new event with the associated arguments. + void TraceEventValues(std::shared_ptr et, const zeek::Args* args); - // Invoked when the current event finishes execution. The arguments - // are again provided, for convenience so we don't have to remember - // them from the previous method. - void FinishCurrentEvent(const zeek::Args* args); + // Invoked when the current event finishes execution. The arguments + // are again provided, for convenience so we don't have to remember + // them from the previous method. + void FinishCurrentEvent(const zeek::Args* args); - // Returns the name of the script variable associated with the - // given value. - const std::string& ValName(const ValPtr& v); - const std::string& ValName(const ValTrace* vt) { return ValName(vt->GetVal()); } + // Returns the name of the script variable associated with the + // given value. + const std::string& ValName(const ValPtr& v); + const std::string& ValName(const ValTrace* vt) { return ValName(vt->GetVal()); } - // Returns true if the script variable associated with the given value - // needs to be global (because it's used across multiple events). - bool IsGlobal(const ValPtr& v) const { return globals.count(v.get()) > 0; } + // Returns true if the script variable associated with the given value + // needs to be global (because it's used across multiple events). + bool IsGlobal(const ValPtr& v) const { return globals.count(v.get()) > 0; } - // Returns or sets the "base time" from which eligible times are - // transformed into offsets rather than maintained as absolute - // values. - double GetBaseTime() const { return base_time; } - void SetBaseTime(double bt) { base_time = bt; } + // Returns or sets the "base time" from which eligible times are + // transformed into offsets rather than maintained as absolute + // values. + double GetBaseTime() const { return base_time; } + void SetBaseTime(double bt) { base_time = bt; } - // Returns a Zeek script representation of the given "time" value. - // This might be relative to base_time or might be absolute. - std::string TimeConstant(double t); + // Returns a Zeek script representation of the given "time" value. + // This might be relative to base_time or might be absolute. + std::string TimeConstant(double t); - // Returns the array of per-type-tag constants. - const auto& GetConstants() const { return constants; } + // Returns the array of per-type-tag constants. + const auto& GetConstants() const { return constants; } private: - // Traces the given value, which we may-or-may-not have seen before. - void AddVal(ValPtr v); + // Traces the given value, which we may-or-may-not have seen before. + void AddVal(ValPtr v); - // Creates a new value, associating a script variable with it. - void NewVal(ValPtr v); + // Creates a new value, associating a script variable with it. + void NewVal(ValPtr v); - // Called when the given value is used in an expression that sets - // or updates another value. This lets us track which values are - // used across multiple events, and thus need to be global. - void ValUsed(const ValPtr& v); + // Called when the given value is used in an expression that sets + // or updates another value. This lets us track which values are + // used across multiple events, and thus need to be global. + void ValUsed(const ValPtr& v); - // Compares the two value traces to build up deltas capturing - // the difference between the previous one and the current one. - void AssessChange(const ValTrace* vt, const ValTrace* prev_vt); + // Compares the two value traces to build up deltas capturing + // the difference between the previous one and the current one. + void AssessChange(const ValTrace* vt, const ValTrace* prev_vt); - // Create and track a script variable associated with the given value. - void TrackVar(const Val* vt); + // Create and track a script variable associated with the given value. + void TrackVar(const Val* vt); - // Generates a name for a value. - std::string GenValName(const ValPtr& v); + // Generates a name for a value. + std::string GenValName(const ValPtr& v); - // True if the given value is an unspecified (and empty set, - // table, or vector appearing as a constant rather than an - // already-typed value). - bool IsUnspecifiedAggregate(const ValPtr& v) const; + // True if the given value is an unspecified (and empty set, + // table, or vector appearing as a constant rather than an + // already-typed value). + bool IsUnspecifiedAggregate(const ValPtr& v) const; - // True if the given value has an unsupported type. - bool IsUnsupported(const Val* v) const; + // True if the given value has an unsupported type. + bool IsUnsupported(const Val* v) const; - // Maps values to their associated traces. - std::unordered_map> val_map; + // Maps values to their associated traces. + std::unordered_map> val_map; - // Maps values to the "names" we associated with them. For simple - // values, the name is just a Zeek script constant. For aggregates, - // it's a dedicated script variable. - std::unordered_map val_names; - int num_vars = 0; // the number of dedicated script variables + // Maps values to the "names" we associated with them. For simple + // values, the name is just a Zeek script constant. For aggregates, + // it's a dedicated script variable. + std::unordered_map val_names; + int num_vars = 0; // the number of dedicated script variables - // Tracks which values we've processed up through the preceding event. - // Any re-use we then see for the current event (via a ValUsed() call) - // then tells us that the value is used across events, and thus its - // associated script variable needs to be global. - std::unordered_set processed_vals; + // Tracks which values we've processed up through the preceding event. + // Any re-use we then see for the current event (via a ValUsed() call) + // then tells us that the value is used across events, and thus its + // associated script variable needs to be global. + std::unordered_set processed_vals; - // Tracks which values have associated script variables that need - // to be global. - std::unordered_set globals; + // Tracks which values have associated script variables that need + // to be global. + std::unordered_set globals; - // Indexed by type tag, stores an ordered set of all of the distinct - // representations of constants of that type. - std::array, NUM_TYPES> constants; + // Indexed by type tag, stores an ordered set of all of the distinct + // representations of constants of that type. + std::array, NUM_TYPES> constants; - // If non-zero, then we've established a "base time" and will report - // time constants as offsets from it (when reasonable, i.e., no - // negative offsets, and base_time can't be too close to 0.0). - double base_time = 0.0; + // If non-zero, then we've established a "base time" and will report + // time constants as offsets from it (when reasonable, i.e., no + // negative offsets, and base_time can't be too close to 0.0). + double base_time = 0.0; - // The event we're currently tracing. - std::shared_ptr curr_ev; + // The event we're currently tracing. + std::shared_ptr curr_ev; - // Hang on to values we're tracking to make sure the pointers don't - // get reused when the main use of the value ends. - std::vector vals; - }; + // Hang on to values we're tracking to make sure the pointers don't + // get reused when the main use of the value ends. + std::vector vals; +}; // Manages tracing of all of the events seen during execution, including // the final generation of the trace script. -class EventTraceMgr - { +class EventTraceMgr { public: - EventTraceMgr(const std::string& trace_file); - ~EventTraceMgr(); + EventTraceMgr(const std::string& trace_file); + ~EventTraceMgr(); - // Called at the beginning of invoking an event's handlers. - void StartEvent(const ScriptFunc* ev, const zeek::Args* args); + // Called at the beginning of invoking an event's handlers. + void StartEvent(const ScriptFunc* ev, const zeek::Args* args); - // Called after finishing with invoking an event's handlers. - void EndEvent(const ScriptFunc* ev, const zeek::Args* args); + // Called after finishing with invoking an event's handlers. + void EndEvent(const ScriptFunc* ev, const zeek::Args* args); - // Used to track events generated at script-level. - void ScriptEventQueued(const EventHandlerPtr& h); + // Used to track events generated at script-level. + void ScriptEventQueued(const EventHandlerPtr& h); private: - FILE* f = nullptr; - ValTraceMgr vtm; + FILE* f = nullptr; + ValTraceMgr vtm; - // All of the events we've traced so far. - std::vector> events; + // All of the events we've traced so far. + std::vector> events; - // The names of all of the script events that have been generated. - std::unordered_set script_events; - }; + // The names of all of the script events that have been generated. + std::unordered_set script_events; +}; // If non-nil then we're doing event tracing. extern std::unique_ptr etm; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Expr.cc b/src/Expr.cc index 6f037e63c5..80f48dd17c 100644 --- a/src/Expr.cc +++ b/src/Expr.cc @@ -27,5559 +27,4781 @@ #include "zeek/script_opt/ExprOptInfo.h" #include "zeek/script_opt/ScriptOpt.h" -namespace zeek::detail - { +namespace zeek::detail { -const char* expr_name(ExprTag t) - { - static const char* expr_names[int(NUM_EXPRS)] = { - "name", - "const", - "(*)", - "++", - "--", - "!", - "~", - "+", - "-", - "+", - "-", - "+=", - "-=", - "*", - "/", - "/", // mask operator - "%", - "&", - "|", - "^", - "<<", - ">>", - "&&", - "||", - "<", - "<=", - "==", - "!=", - ">=", - ">", - "?:", - "ref", - "=", - "[]", - "$", - "?$", - "[=]", - "table()", - "set()", - "vector()", - "$=", - "in", - "<<>>", - "()", - "function()", - "event", - "schedule", - "coerce", - "record_coerce", - "table_coerce", - "vector_coerce", - "sizeof", - "cast", - "is", - "[:]=", - "inline()", - "[]=", - "$=", - "vec+=", - "to_any_coerce", - "from_any_coerce", - "from_any_vec_coerce", - "any[]", - "nop", +const char* expr_name(ExprTag t) { + static const char* expr_names[int(NUM_EXPRS)] = { + "name", + "const", + "(*)", + "++", + "--", + "!", + "~", + "+", + "-", + "+", + "-", + "+=", + "-=", + "*", + "/", + "/", // mask operator + "%", + "&", + "|", + "^", + "<<", + ">>", + "&&", + "||", + "<", + "<=", + "==", + "!=", + ">=", + ">", + "?:", + "ref", + "=", + "[]", + "$", + "?$", + "[=]", + "table()", + "set()", + "vector()", + "$=", + "in", + "<<>>", + "()", + "function()", + "event", + "schedule", + "coerce", + "record_coerce", + "table_coerce", + "vector_coerce", + "sizeof", + "cast", + "is", + "[:]=", + "inline()", + "[]=", + "$=", + "vec+=", + "to_any_coerce", + "from_any_coerce", + "from_any_vec_coerce", + "any[]", + "nop", - }; + }; - if ( int(t) >= NUM_EXPRS ) - { - static char errbuf[512]; + if ( int(t) >= NUM_EXPRS ) { + static char errbuf[512]; - // This isn't quite right - we return a static buffer, - // so multiple calls to expr_name() could lead to confusion - // by overwriting the buffer. But oh well. - snprintf(errbuf, sizeof(errbuf), "%d: not an expression tag", int(t)); - return errbuf; - } + // This isn't quite right - we return a static buffer, + // so multiple calls to expr_name() could lead to confusion + // by overwriting the buffer. But oh well. + snprintf(errbuf, sizeof(errbuf), "%d: not an expression tag", int(t)); + return errbuf; + } - return expr_names[int(t)]; - } + return expr_names[int(t)]; +} int Expr::num_exprs = 0; -Expr::Expr(ExprTag arg_tag) : tag(arg_tag), paren(false), type(nullptr) - { - SetLocationInfo(&start_location, &end_location); - opt_info = new ExprOptInfo(); - ++num_exprs; - } - -Expr::~Expr() - { - delete opt_info; - } - -const ListExpr* Expr::AsListExpr() const - { - CHECK_TAG(tag, EXPR_LIST, "Expr::AsListExpr", expr_name) - return (const ListExpr*)this; - } - -ListExpr* Expr::AsListExpr() - { - CHECK_TAG(tag, EXPR_LIST, "Expr::AsListExpr", expr_name) - return (ListExpr*)this; - } - -ListExprPtr Expr::AsListExprPtr() - { - CHECK_TAG(tag, EXPR_LIST, "Expr::AsListExpr", expr_name) - return {NewRef{}, (ListExpr*)this}; - } - -const NameExpr* Expr::AsNameExpr() const - { - CHECK_TAG(tag, EXPR_NAME, "Expr::AsNameExpr", expr_name) - return (const NameExpr*)this; - } - -NameExpr* Expr::AsNameExpr() - { - CHECK_TAG(tag, EXPR_NAME, "Expr::AsNameExpr", expr_name) - return (NameExpr*)this; - } - -NameExprPtr Expr::AsNameExprPtr() - { - CHECK_TAG(tag, EXPR_NAME, "Expr::AsNameExpr", expr_name) - return {NewRef{}, (NameExpr*)this}; - } - -const ConstExpr* Expr::AsConstExpr() const - { - CHECK_TAG(tag, EXPR_CONST, "Expr::AsConstExpr", expr_name) - return (const ConstExpr*)this; - } - -ConstExprPtr Expr::AsConstExprPtr() - { - CHECK_TAG(tag, EXPR_CONST, "Expr::AsConstExpr", expr_name) - return {NewRef{}, (ConstExpr*)this}; - } - -const CallExpr* Expr::AsCallExpr() const - { - CHECK_TAG(tag, EXPR_CALL, "Expr::AsCallExpr", expr_name) - return (const CallExpr*)this; - } - -const AssignExpr* Expr::AsAssignExpr() const - { - CHECK_TAG(tag, EXPR_ASSIGN, "Expr::AsAssignExpr", expr_name) - return (const AssignExpr*)this; - } - -AssignExpr* Expr::AsAssignExpr() - { - CHECK_TAG(tag, EXPR_ASSIGN, "Expr::AsAssignExpr", expr_name) - return (AssignExpr*)this; - } - -const IndexExpr* Expr::AsIndexExpr() const - { - CHECK_TAG(tag, EXPR_INDEX, "Expr::AsIndexExpr", expr_name) - return (const IndexExpr*)this; - } - -IndexExpr* Expr::AsIndexExpr() - { - CHECK_TAG(tag, EXPR_INDEX, "Expr::AsIndexExpr", expr_name) - return (IndexExpr*)this; - } - -const EventExpr* Expr::AsEventExpr() const - { - CHECK_TAG(tag, EXPR_EVENT, "Expr::AsEventExpr", expr_name) - return (const EventExpr*)this; - } - -EventExprPtr Expr::AsEventExprPtr() - { - CHECK_TAG(tag, EXPR_EVENT, "Expr::AsEventExpr", expr_name) - return {NewRef{}, (EventExpr*)this}; - } - -const RefExpr* Expr::AsRefExpr() const - { - CHECK_TAG(tag, EXPR_REF, "Expr::AsRefExpr", expr_name) - return (const RefExpr*)this; - } - -RefExprPtr Expr::AsRefExprPtr() - { - CHECK_TAG(tag, EXPR_REF, "Expr::AsRefExpr", expr_name) - return {NewRef{}, (RefExpr*)this}; - } - -bool Expr::CanAdd() const - { - return false; - } - -bool Expr::CanDel() const - { - return false; - } - -void Expr::Add(Frame* /* f */) - { - Internal("Expr::Add called"); - } - -void Expr::Delete(Frame* /* f */) - { - Internal("Expr::Delete called"); - } - -ExprPtr Expr::MakeLvalue() - { - if ( ! IsError() ) - ExprError("can't be assigned to"); - - return ThisPtr(); - } - -bool Expr::InvertSense() - { - return false; - } - -void Expr::Assign(Frame* /* f */, ValPtr /* v */) - { - Internal("Expr::Assign called"); - } - -void Expr::AssignToIndex(ValPtr v1, ValPtr v2, ValPtr v3) const - { - bool iterators_invalidated; - - auto error_msg = assign_to_index(std::move(v1), std::move(v2), std::move(v3), - iterators_invalidated); - - if ( iterators_invalidated ) - reporter->ExprRuntimeWarning(this, "possible loop/iterator invalidation"); - - if ( error_msg ) - RuntimeErrorWithCallStack(error_msg); - } - -static int get_slice_index(int idx, int len) - { - if ( abs(idx) > len ) - idx = idx > 0 ? len : 0; // Clamp maximum positive/negative indices. - else if ( idx < 0 ) - idx += len; // Map to a positive index. - - return idx; - } - -const char* assign_to_index(ValPtr v1, ValPtr v2, ValPtr v3, bool& iterators_invalidated) - { - iterators_invalidated = false; - - if ( ! v1 || ! v2 || ! v3 ) - return nullptr; - - // Hold an extra reference in case the ownership transfer - // to the table/vector goes wrong and we still want to obtain - // diagnostic info from the original value after the assignment - // already unref'd. - auto v_extra = v3; - - switch ( v1->GetType()->Tag() ) - { - case TYPE_VECTOR: - { - const ListVal* lv = v2->AsListVal(); - VectorVal* v1_vect = v1->AsVectorVal(); - - if ( lv->Length() > 1 ) - { - auto len = v1_vect->Size(); - zeek_int_t first = get_slice_index(lv->Idx(0)->CoerceToInt(), len); - zeek_int_t last = get_slice_index(lv->Idx(1)->CoerceToInt(), len); - - // Remove the elements from the vector within the slice. - for ( auto idx = first; idx < last; idx++ ) - v1_vect->Remove(first); - - // Insert the new elements starting at the first - // position. - - VectorVal* v_vect = v3->AsVectorVal(); - - for ( auto idx = 0u; idx < v_vect->Size(); idx++, first++ ) - v1_vect->Insert(first, v_vect->ValAt(idx)); - } - - else if ( ! v1_vect->Assign(lv->Idx(0)->CoerceToUnsigned(), std::move(v3)) ) - { - v3 = std::move(v_extra); - - if ( v3 ) - { - ODesc d; - v3->Describe(&d); - const auto& vt = v3->GetType(); - auto vtt = vt->Tag(); - std::string tn = vtt == TYPE_RECORD ? vt->GetName() : type_name(vtt); - return util::fmt( - "vector index assignment failed for invalid type '%s', value: %s", - tn.data(), d.Description()); - } - else - return "assignment failed with null value"; - } - break; - } - - case TYPE_TABLE: - { - if ( ! v1->AsTableVal()->Assign(std::move(v2), std::move(v3), true, - &iterators_invalidated) ) - { - v3 = std::move(v_extra); - - if ( v3 ) - { - ODesc d; - v3->Describe(&d); - const auto& vt = v3->GetType(); - auto vtt = vt->Tag(); - std::string tn = vtt == TYPE_RECORD ? vt->GetName() : type_name(vtt); - return util::fmt( - "table index assignment failed for invalid type '%s', value: %s", tn.data(), - d.Description()); - } - else - return "assignment failed with null value"; - } - - break; - } - - case TYPE_STRING: - return "assignment via string index accessor not allowed"; - break; - - default: - return "bad index expression type in assignment"; - break; - } - - return nullptr; - } - -TypePtr Expr::InitType() const - { - return type; - } - -bool Expr::IsRecordElement(TypeDecl* /* td */) const - { - return false; - } - -bool Expr::IsError() const - { - return type && type->Tag() == TYPE_ERROR; - } - -void Expr::SetError() - { - SetType(error_type()); - } - -void Expr::SetError(const char* msg) - { - Error(msg); - SetError(); - } - -bool Expr::IsZero() const - { - return IsConst() && ExprVal()->IsZero(); - } - -bool Expr::IsOne() const - { - return IsConst() && ExprVal()->IsOne(); - } - -void Expr::Describe(ODesc* d) const - { - if ( IsParen() && ! d->IsBinary() ) - d->Add("("); - - if ( d->IsBinary() ) - AddTag(d); - - ExprDescribe(d); - - if ( IsParen() && ! d->IsBinary() ) - d->Add(")"); - } - -void Expr::AddTag(ODesc* d) const - { - if ( d->IsBinary() ) - d->Add(int(Tag())); - else - d->AddSP(expr_name(Tag())); - } - -void Expr::Canonicalize() { } - -void Expr::SetType(TypePtr t) - { - if ( ! type || type->Tag() != TYPE_ERROR ) - type = std::move(t); - } - -void Expr::ExprError(const char msg[]) - { - Error(msg); - SetError(); - } - -void Expr::RuntimeError(const std::string& msg) const - { - reporter->ExprRuntimeError(this, "%s", msg.data()); - } - -void Expr::RuntimeErrorWithCallStack(const std::string& msg) const - { - auto rcs = render_call_stack(); - - if ( rcs.empty() ) - reporter->ExprRuntimeError(this, "%s", msg.data()); - else - { - ODesc d; - d.SetShort(); - Describe(&d); - reporter->RuntimeError(GetLocationInfo(), "%s, expression: %s, call stack: %s", msg.data(), - d.Description(), rcs.data()); - } - } - -NameExpr::NameExpr(IDPtr arg_id, bool const_init) : Expr(EXPR_NAME), id(std::move(arg_id)) - { - in_const_init = const_init; - - if ( id->IsType() ) - SetType(make_intrusive(id->GetType())); - else - SetType(id->GetType()); - - EventHandler* h = event_registry->Lookup(id->Name()); - if ( h ) - h->SetUsed(); - } +Expr::Expr(ExprTag arg_tag) : tag(arg_tag), paren(false), type(nullptr) { + SetLocationInfo(&start_location, &end_location); + opt_info = new ExprOptInfo(); + ++num_exprs; +} + +Expr::~Expr() { delete opt_info; } + +const ListExpr* Expr::AsListExpr() const { + CHECK_TAG(tag, EXPR_LIST, "Expr::AsListExpr", expr_name) + return (const ListExpr*)this; +} + +ListExpr* Expr::AsListExpr() { + CHECK_TAG(tag, EXPR_LIST, "Expr::AsListExpr", expr_name) + return (ListExpr*)this; +} + +ListExprPtr Expr::AsListExprPtr() { + CHECK_TAG(tag, EXPR_LIST, "Expr::AsListExpr", expr_name) + return {NewRef{}, (ListExpr*)this}; +} + +const NameExpr* Expr::AsNameExpr() const { + CHECK_TAG(tag, EXPR_NAME, "Expr::AsNameExpr", expr_name) + return (const NameExpr*)this; +} + +NameExpr* Expr::AsNameExpr() { + CHECK_TAG(tag, EXPR_NAME, "Expr::AsNameExpr", expr_name) + return (NameExpr*)this; +} + +NameExprPtr Expr::AsNameExprPtr() { + CHECK_TAG(tag, EXPR_NAME, "Expr::AsNameExpr", expr_name) + return {NewRef{}, (NameExpr*)this}; +} + +const ConstExpr* Expr::AsConstExpr() const { + CHECK_TAG(tag, EXPR_CONST, "Expr::AsConstExpr", expr_name) + return (const ConstExpr*)this; +} + +ConstExprPtr Expr::AsConstExprPtr() { + CHECK_TAG(tag, EXPR_CONST, "Expr::AsConstExpr", expr_name) + return {NewRef{}, (ConstExpr*)this}; +} + +const CallExpr* Expr::AsCallExpr() const { + CHECK_TAG(tag, EXPR_CALL, "Expr::AsCallExpr", expr_name) + return (const CallExpr*)this; +} + +const AssignExpr* Expr::AsAssignExpr() const { + CHECK_TAG(tag, EXPR_ASSIGN, "Expr::AsAssignExpr", expr_name) + return (const AssignExpr*)this; +} + +AssignExpr* Expr::AsAssignExpr() { + CHECK_TAG(tag, EXPR_ASSIGN, "Expr::AsAssignExpr", expr_name) + return (AssignExpr*)this; +} + +const IndexExpr* Expr::AsIndexExpr() const { + CHECK_TAG(tag, EXPR_INDEX, "Expr::AsIndexExpr", expr_name) + return (const IndexExpr*)this; +} + +IndexExpr* Expr::AsIndexExpr() { + CHECK_TAG(tag, EXPR_INDEX, "Expr::AsIndexExpr", expr_name) + return (IndexExpr*)this; +} + +const EventExpr* Expr::AsEventExpr() const { + CHECK_TAG(tag, EXPR_EVENT, "Expr::AsEventExpr", expr_name) + return (const EventExpr*)this; +} + +EventExprPtr Expr::AsEventExprPtr() { + CHECK_TAG(tag, EXPR_EVENT, "Expr::AsEventExpr", expr_name) + return {NewRef{}, (EventExpr*)this}; +} + +const RefExpr* Expr::AsRefExpr() const { + CHECK_TAG(tag, EXPR_REF, "Expr::AsRefExpr", expr_name) + return (const RefExpr*)this; +} + +RefExprPtr Expr::AsRefExprPtr() { + CHECK_TAG(tag, EXPR_REF, "Expr::AsRefExpr", expr_name) + return {NewRef{}, (RefExpr*)this}; +} + +bool Expr::CanAdd() const { return false; } + +bool Expr::CanDel() const { return false; } + +void Expr::Add(Frame* /* f */) { Internal("Expr::Add called"); } + +void Expr::Delete(Frame* /* f */) { Internal("Expr::Delete called"); } + +ExprPtr Expr::MakeLvalue() { + if ( ! IsError() ) + ExprError("can't be assigned to"); + + return ThisPtr(); +} + +bool Expr::InvertSense() { return false; } + +void Expr::Assign(Frame* /* f */, ValPtr /* v */) { Internal("Expr::Assign called"); } + +void Expr::AssignToIndex(ValPtr v1, ValPtr v2, ValPtr v3) const { + bool iterators_invalidated; + + auto error_msg = assign_to_index(std::move(v1), std::move(v2), std::move(v3), iterators_invalidated); + + if ( iterators_invalidated ) + reporter->ExprRuntimeWarning(this, "possible loop/iterator invalidation"); + + if ( error_msg ) + RuntimeErrorWithCallStack(error_msg); +} + +static int get_slice_index(int idx, int len) { + if ( abs(idx) > len ) + idx = idx > 0 ? len : 0; // Clamp maximum positive/negative indices. + else if ( idx < 0 ) + idx += len; // Map to a positive index. + + return idx; +} + +const char* assign_to_index(ValPtr v1, ValPtr v2, ValPtr v3, bool& iterators_invalidated) { + iterators_invalidated = false; + + if ( ! v1 || ! v2 || ! v3 ) + return nullptr; + + // Hold an extra reference in case the ownership transfer + // to the table/vector goes wrong and we still want to obtain + // diagnostic info from the original value after the assignment + // already unref'd. + auto v_extra = v3; + + switch ( v1->GetType()->Tag() ) { + case TYPE_VECTOR: { + const ListVal* lv = v2->AsListVal(); + VectorVal* v1_vect = v1->AsVectorVal(); + + if ( lv->Length() > 1 ) { + auto len = v1_vect->Size(); + zeek_int_t first = get_slice_index(lv->Idx(0)->CoerceToInt(), len); + zeek_int_t last = get_slice_index(lv->Idx(1)->CoerceToInt(), len); + + // Remove the elements from the vector within the slice. + for ( auto idx = first; idx < last; idx++ ) + v1_vect->Remove(first); + + // Insert the new elements starting at the first + // position. + + VectorVal* v_vect = v3->AsVectorVal(); + + for ( auto idx = 0u; idx < v_vect->Size(); idx++, first++ ) + v1_vect->Insert(first, v_vect->ValAt(idx)); + } + + else if ( ! v1_vect->Assign(lv->Idx(0)->CoerceToUnsigned(), std::move(v3)) ) { + v3 = std::move(v_extra); + + if ( v3 ) { + ODesc d; + v3->Describe(&d); + const auto& vt = v3->GetType(); + auto vtt = vt->Tag(); + std::string tn = vtt == TYPE_RECORD ? vt->GetName() : type_name(vtt); + return util::fmt("vector index assignment failed for invalid type '%s', value: %s", tn.data(), + d.Description()); + } + else + return "assignment failed with null value"; + } + break; + } + + case TYPE_TABLE: { + if ( ! v1->AsTableVal()->Assign(std::move(v2), std::move(v3), true, &iterators_invalidated) ) { + v3 = std::move(v_extra); + + if ( v3 ) { + ODesc d; + v3->Describe(&d); + const auto& vt = v3->GetType(); + auto vtt = vt->Tag(); + std::string tn = vtt == TYPE_RECORD ? vt->GetName() : type_name(vtt); + return util::fmt("table index assignment failed for invalid type '%s', value: %s", tn.data(), + d.Description()); + } + else + return "assignment failed with null value"; + } + + break; + } + + case TYPE_STRING: return "assignment via string index accessor not allowed"; break; + + default: return "bad index expression type in assignment"; break; + } + + return nullptr; +} + +TypePtr Expr::InitType() const { return type; } + +bool Expr::IsRecordElement(TypeDecl* /* td */) const { return false; } + +bool Expr::IsError() const { return type && type->Tag() == TYPE_ERROR; } + +void Expr::SetError() { SetType(error_type()); } + +void Expr::SetError(const char* msg) { + Error(msg); + SetError(); +} + +bool Expr::IsZero() const { return IsConst() && ExprVal()->IsZero(); } + +bool Expr::IsOne() const { return IsConst() && ExprVal()->IsOne(); } + +void Expr::Describe(ODesc* d) const { + if ( IsParen() && ! d->IsBinary() ) + d->Add("("); + + if ( d->IsBinary() ) + AddTag(d); + + ExprDescribe(d); + + if ( IsParen() && ! d->IsBinary() ) + d->Add(")"); +} + +void Expr::AddTag(ODesc* d) const { + if ( d->IsBinary() ) + d->Add(int(Tag())); + else + d->AddSP(expr_name(Tag())); +} + +void Expr::Canonicalize() {} + +void Expr::SetType(TypePtr t) { + if ( ! type || type->Tag() != TYPE_ERROR ) + type = std::move(t); +} + +void Expr::ExprError(const char msg[]) { + Error(msg); + SetError(); +} + +void Expr::RuntimeError(const std::string& msg) const { reporter->ExprRuntimeError(this, "%s", msg.data()); } + +void Expr::RuntimeErrorWithCallStack(const std::string& msg) const { + auto rcs = render_call_stack(); + + if ( rcs.empty() ) + reporter->ExprRuntimeError(this, "%s", msg.data()); + else { + ODesc d; + d.SetShort(); + Describe(&d); + reporter->RuntimeError(GetLocationInfo(), "%s, expression: %s, call stack: %s", msg.data(), d.Description(), + rcs.data()); + } +} + +NameExpr::NameExpr(IDPtr arg_id, bool const_init) : Expr(EXPR_NAME), id(std::move(arg_id)) { + in_const_init = const_init; + + if ( id->IsType() ) + SetType(make_intrusive(id->GetType())); + else + SetType(id->GetType()); + + EventHandler* h = event_registry->Lookup(id->Name()); + if ( h ) + h->SetUsed(); +} // This isn't in-lined to avoid needing to pull in ID.h. -const IDPtr& NameExpr::IdPtr() const - { - return id; - } - -ValPtr NameExpr::Eval(Frame* f) const - { - ValPtr v; - - if ( id->IsType() ) - return make_intrusive(id->GetType(), true); - - if ( id->IsGlobal() ) - v = id->GetVal(); - - else if ( f ) - v = f->GetElementByID(id); - - else - // No frame - evaluating for purposes of resolving a - // compile-time constant. - return nullptr; - - if ( v ) - return v; - else - { - RuntimeError("value used but not set"); - return nullptr; - } - } - -ExprPtr NameExpr::MakeLvalue() - { - if ( id->IsType() ) - ExprError("Type name is not an lvalue"); - - if ( id->IsConst() && ! in_const_init ) - ExprError("const is not a modifiable lvalue"); - - if ( id->IsOption() && ! in_const_init ) - ExprError("option is not a modifiable lvalue"); - - return make_intrusive(ThisPtr()); - } - -void NameExpr::Assign(Frame* f, ValPtr v) - { - if ( id->IsGlobal() ) - id->SetVal(std::move(v)); - else - f->SetElement(id, std::move(v)); - } - -TraversalCode NameExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - tc = id->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -void NameExpr::ExprDescribe(ODesc* d) const - { - if ( d->IsReadable() ) - d->Add(id->Name()); - else - d->AddCS(id->Name()); - } - -ConstExpr::ConstExpr(ValPtr arg_val) : Expr(EXPR_CONST), val(std::move(arg_val)) - { - if ( val ) - { - if ( val->GetType()->Tag() == TYPE_LIST && val->AsListVal()->Length() == 1 ) - val = val->AsListVal()->Idx(0); - - SetType(val->GetType()); - } - else - SetError(); - } - -void ConstExpr::ExprDescribe(ODesc* d) const - { - val->Describe(d); - } - -ValPtr ConstExpr::Eval(Frame* /* f */) const - { - return {NewRef{}, Value()}; - } - -TraversalCode ConstExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -UnaryExpr::UnaryExpr(ExprTag arg_tag, ExprPtr arg_op) : Expr(arg_tag), op(std::move(arg_op)) - { - if ( op->IsError() ) - SetError(); - } - -ValPtr UnaryExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; - - auto v = op->Eval(f); - - if ( ! v ) - return nullptr; - - if ( is_vector(v) && Tag() != EXPR_IS && Tag() != EXPR_CAST && - // The following allows passing vectors-by-reference to - // functions that use vector-of-any for generic vector - // manipulation ... - Tag() != EXPR_TO_ANY_COERCE && - // ... and the following to avoid vectorizing operations - // on vector-of-any's - Tag() != EXPR_FROM_ANY_COERCE ) - { - VectorVal* v_op = v->AsVectorVal(); - VectorTypePtr out_t; - - if ( GetType()->Tag() == TYPE_ANY ) - out_t = v->GetType(); - else - out_t = GetType(); - - auto result = make_intrusive(std::move(out_t)); - - for ( unsigned int i = 0; i < v_op->Size(); ++i ) - { - auto vop = v_op->ValAt(i); - if ( vop ) - result->Assign(i, Fold(vop.get())); - else - result->Assign(i, nullptr); - } - - return result; - } - else - return Fold(v.get()); - } - -bool UnaryExpr::IsPure() const - { - return op->IsPure(); - } - -TraversalCode UnaryExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - tc = op->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -ValPtr UnaryExpr::Fold(Val* v) const - { - return {NewRef{}, v}; - } - -void UnaryExpr::ExprDescribe(ODesc* d) const - { - bool is_coerce = Tag() == EXPR_ARITH_COERCE || Tag() == EXPR_RECORD_COERCE || - Tag() == EXPR_TABLE_COERCE; - - if ( d->IsReadable() ) - { - if ( is_coerce ) - d->Add("(coerce "); - else if ( Tag() != EXPR_REF ) - d->Add(expr_name(Tag())); - } - - op->Describe(d); - - if ( d->IsReadable() && is_coerce ) - { - d->Add(" to "); - GetType()->Describe(d); - d->Add(")"); - } - } - -ValPtr BinaryExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; - - auto v1 = op1->Eval(f); - - if ( ! v1 ) - return nullptr; - - auto v2 = op2->Eval(f); - - if ( ! v2 ) - return nullptr; - - bool is_vec1 = is_vector(v1); - bool is_vec2 = is_vector(v2); - - if ( is_vec1 && is_vec2 ) - { // fold pairs of elements - VectorVal* v_op1 = v1->AsVectorVal(); - VectorVal* v_op2 = v2->AsVectorVal(); - - if ( v_op1->Size() != v_op2->Size() ) - { - RuntimeError("vector operands are of different sizes"); - return nullptr; - } - - auto v_result = make_intrusive(GetType()); - - for ( unsigned int i = 0; i < v_op1->Size(); ++i ) - { - auto v1_i = v_op1->ValAt(i); - auto v2_i = v_op2->ValAt(i); - if ( v1_i && v2_i ) - v_result->Assign(i, Fold(v_op1->ValAt(i).get(), v_op2->ValAt(i).get())); - else - v_result->Assign(i, nullptr); - } - - return v_result; - } - - if ( IsVector(GetType()->Tag()) && (is_vec1 || is_vec2) ) - { // fold vector against scalar - VectorVal* vv = (is_vec1 ? v1 : v2)->AsVectorVal(); - auto v_result = make_intrusive(GetType()); - - for ( unsigned int i = 0; i < vv->Size(); ++i ) - { - auto vv_i = vv->ValAt(i); - if ( vv_i ) - v_result->Assign(i, - is_vec1 ? Fold(vv_i.get(), v2.get()) : Fold(v1.get(), vv_i.get())); - else - v_result->Assign(i, nullptr); - } - - return v_result; - } - - // scalar op scalar - return Fold(v1.get(), v2.get()); - } - -bool BinaryExpr::IsPure() const - { - return op1->IsPure() && op2->IsPure(); - } - -TraversalCode BinaryExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - tc = op1->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = op2->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -void BinaryExpr::ExprDescribe(ODesc* d) const - { - op1->Describe(d); - - d->SP(); - if ( d->IsReadable() ) - d->AddSP(expr_name(Tag())); - - op2->Describe(d); - } - -ValPtr BinaryExpr::Fold(Val* v1, Val* v2) const - { - auto& t1 = v1->GetType(); - InternalTypeTag it = t1->InternalType(); - - if ( it == TYPE_INTERNAL_STRING ) - return StringFold(v1, v2); - - if ( t1->Tag() == TYPE_PATTERN ) - return PatternFold(v1, v2); - - if ( t1->IsSet() ) - return SetFold(v1, v2); - - if ( t1->IsTable() ) - return TableFold(v1, v2); - - if ( t1->Tag() == TYPE_VECTOR ) - { - // We only get here when using a matching vector on the RHS. - if ( ! v2->AsVectorVal()->AddTo(v1, false) ) - Error("incompatible vector element assignment", v2); - return {NewRef{}, v1}; - } - - if ( it == TYPE_INTERNAL_ADDR ) - return AddrFold(v1, v2); - - if ( it == TYPE_INTERNAL_SUBNET ) - return SubNetFold(v1, v2); - - zeek_int_t i1 = 0, i2 = 0, i3 = 0; - zeek_uint_t u1 = 0, u2 = 0, u3 = 0; - double d1 = 0.0, d2 = 0.0, d3 = 0.0; - bool is_integral = false; - bool is_unsigned = false; - - if ( it == TYPE_INTERNAL_INT ) - { - i1 = v1->InternalInt(); - i2 = v2->InternalInt(); - is_integral = true; - } - else if ( it == TYPE_INTERNAL_UNSIGNED ) - { - u1 = v1->InternalUnsigned(); - u2 = v2->InternalUnsigned(); - is_unsigned = true; - } - else if ( it == TYPE_INTERNAL_DOUBLE ) - { - d1 = v1->InternalDouble(); - d2 = v2->InternalDouble(); - } - else - RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); - - switch ( tag ) - { -#define DO_INT_FOLD(op) \ - if ( is_integral ) \ - i3 = i1 op i2; \ - else if ( is_unsigned ) \ - u3 = u1 op u2; \ - else \ - RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); - -#define DO_UINT_FOLD(op) \ - if ( is_unsigned ) \ - u3 = u1 op u2; \ - else \ - RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); - -#define DO_FOLD(op) \ - if ( is_integral ) \ - i3 = i1 op i2; \ - else if ( is_unsigned ) \ - u3 = u1 op u2; \ - else \ - d3 = d1 op d2; - -#define DO_INT_VAL_FOLD(op) \ - if ( is_integral ) \ - i3 = i1 op i2; \ - else if ( is_unsigned ) \ - i3 = u1 op u2; \ - else \ - i3 = d1 op d2; - - case EXPR_ADD: - case EXPR_ADD_TO: - DO_FOLD(+); - break; - case EXPR_SUB: - case EXPR_REMOVE_FROM: - DO_FOLD(-); - // When subtracting and the result is larger than the left - // operand we mostly likely underflowed and log a warning. - if ( is_unsigned && u3 > u1 ) - reporter->ExprRuntimeWarning(this, "count underflow"); - break; - case EXPR_TIMES: - DO_FOLD(*); - break; - case EXPR_DIVIDE: - { - if ( is_integral ) - { - if ( i2 == 0 ) - RuntimeError("division by zero"); - - i3 = i1 / i2; - } - - else if ( is_unsigned ) - { - if ( u2 == 0 ) - RuntimeError("division by zero"); - - u3 = u1 / u2; - } - else - { - if ( d2 == 0 ) - RuntimeError("division by zero"); - - d3 = d1 / d2; - } - } - break; - - case EXPR_MOD: - { - if ( is_integral ) - { - if ( i2 == 0 ) - RuntimeError("modulo by zero"); - - i3 = i1 % i2; - } - - else if ( is_unsigned ) - { - if ( u2 == 0 ) - RuntimeError("modulo by zero"); - - u3 = u1 % u2; - } - - else - RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); - } - - break; - - case EXPR_AND: - DO_UINT_FOLD(&); - break; - case EXPR_OR: - DO_UINT_FOLD(|); - break; - case EXPR_XOR: - DO_UINT_FOLD(^); - break; - case EXPR_LSHIFT: - { - if ( is_integral ) - { - if ( i1 < 0 ) - RuntimeError("left shifting a negative number is undefined"); - - i3 = i1 << static_cast(i2); - } - - else if ( is_unsigned ) - u3 = u1 << u2; - - else - RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); - break; - } - case EXPR_RSHIFT: - { - if ( is_integral ) - i3 = i1 >> static_cast(i2); - - else if ( is_unsigned ) - u3 = u1 >> u2; - - else - RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); - break; - } - - case EXPR_AND_AND: - DO_INT_FOLD(&&); - break; - case EXPR_OR_OR: - DO_INT_FOLD(||); - break; - - case EXPR_LT: - DO_INT_VAL_FOLD(<); - break; - case EXPR_LE: - DO_INT_VAL_FOLD(<=); - break; - case EXPR_EQ: - DO_INT_VAL_FOLD(==); - break; - case EXPR_NE: - DO_INT_VAL_FOLD(!=); - break; - case EXPR_GE: - DO_INT_VAL_FOLD(>=); - break; - case EXPR_GT: - DO_INT_VAL_FOLD(>); - break; - - default: - BadTag("BinaryExpr::Fold", expr_name(tag)); - } - - const auto& ret_type = IsVector(GetType()->Tag()) ? GetType()->Yield() : GetType(); - - if ( ret_type->Tag() == TYPE_INTERVAL ) - return make_intrusive(d3); - else if ( ret_type->Tag() == TYPE_TIME ) - return make_intrusive(d3); - else if ( ret_type->Tag() == TYPE_DOUBLE ) - return make_intrusive(d3); - else if ( ret_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) - return val_mgr->Count(u3); - else if ( ret_type->Tag() == TYPE_BOOL ) - return val_mgr->Bool(i3); - else - return val_mgr->Int(i3); - } - -ValPtr BinaryExpr::StringFold(Val* v1, Val* v2) const - { - const String* s1 = v1->AsString(); - const String* s2 = v2->AsString(); - int result = 0; - - switch ( tag ) - { +const IDPtr& NameExpr::IdPtr() const { return id; } + +ValPtr NameExpr::Eval(Frame* f) const { + ValPtr v; + + if ( id->IsType() ) + return make_intrusive(id->GetType(), true); + + if ( id->IsGlobal() ) + v = id->GetVal(); + + else if ( f ) + v = f->GetElementByID(id); + + else + // No frame - evaluating for purposes of resolving a + // compile-time constant. + return nullptr; + + if ( v ) + return v; + else { + RuntimeError("value used but not set"); + return nullptr; + } +} + +ExprPtr NameExpr::MakeLvalue() { + if ( id->IsType() ) + ExprError("Type name is not an lvalue"); + + if ( id->IsConst() && ! in_const_init ) + ExprError("const is not a modifiable lvalue"); + + if ( id->IsOption() && ! in_const_init ) + ExprError("option is not a modifiable lvalue"); + + return make_intrusive(ThisPtr()); +} + +void NameExpr::Assign(Frame* f, ValPtr v) { + if ( id->IsGlobal() ) + id->SetVal(std::move(v)); + else + f->SetElement(id, std::move(v)); +} + +TraversalCode NameExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = id->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +void NameExpr::ExprDescribe(ODesc* d) const { + if ( d->IsReadable() ) + d->Add(id->Name()); + else + d->AddCS(id->Name()); +} + +ConstExpr::ConstExpr(ValPtr arg_val) : Expr(EXPR_CONST), val(std::move(arg_val)) { + if ( val ) { + if ( val->GetType()->Tag() == TYPE_LIST && val->AsListVal()->Length() == 1 ) + val = val->AsListVal()->Idx(0); + + SetType(val->GetType()); + } + else + SetError(); +} + +void ConstExpr::ExprDescribe(ODesc* d) const { val->Describe(d); } + +ValPtr ConstExpr::Eval(Frame* /* f */) const { return {NewRef{}, Value()}; } + +TraversalCode ConstExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +UnaryExpr::UnaryExpr(ExprTag arg_tag, ExprPtr arg_op) : Expr(arg_tag), op(std::move(arg_op)) { + if ( op->IsError() ) + SetError(); +} + +ValPtr UnaryExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; + + auto v = op->Eval(f); + + if ( ! v ) + return nullptr; + + if ( is_vector(v) && Tag() != EXPR_IS && Tag() != EXPR_CAST && + // The following allows passing vectors-by-reference to + // functions that use vector-of-any for generic vector + // manipulation ... + Tag() != EXPR_TO_ANY_COERCE && + // ... and the following to avoid vectorizing operations + // on vector-of-any's + Tag() != EXPR_FROM_ANY_COERCE ) { + VectorVal* v_op = v->AsVectorVal(); + VectorTypePtr out_t; + + if ( GetType()->Tag() == TYPE_ANY ) + out_t = v->GetType(); + else + out_t = GetType(); + + auto result = make_intrusive(std::move(out_t)); + + for ( unsigned int i = 0; i < v_op->Size(); ++i ) { + auto vop = v_op->ValAt(i); + if ( vop ) + result->Assign(i, Fold(vop.get())); + else + result->Assign(i, nullptr); + } + + return result; + } + else + return Fold(v.get()); +} + +bool UnaryExpr::IsPure() const { return op->IsPure(); } + +TraversalCode UnaryExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = op->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +ValPtr UnaryExpr::Fold(Val* v) const { return {NewRef{}, v}; } + +void UnaryExpr::ExprDescribe(ODesc* d) const { + bool is_coerce = Tag() == EXPR_ARITH_COERCE || Tag() == EXPR_RECORD_COERCE || Tag() == EXPR_TABLE_COERCE; + + if ( d->IsReadable() ) { + if ( is_coerce ) + d->Add("(coerce "); + else if ( Tag() != EXPR_REF ) + d->Add(expr_name(Tag())); + } + + op->Describe(d); + + if ( d->IsReadable() && is_coerce ) { + d->Add(" to "); + GetType()->Describe(d); + d->Add(")"); + } +} + +ValPtr BinaryExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; + + auto v1 = op1->Eval(f); + + if ( ! v1 ) + return nullptr; + + auto v2 = op2->Eval(f); + + if ( ! v2 ) + return nullptr; + + bool is_vec1 = is_vector(v1); + bool is_vec2 = is_vector(v2); + + if ( is_vec1 && is_vec2 ) { // fold pairs of elements + VectorVal* v_op1 = v1->AsVectorVal(); + VectorVal* v_op2 = v2->AsVectorVal(); + + if ( v_op1->Size() != v_op2->Size() ) { + RuntimeError("vector operands are of different sizes"); + return nullptr; + } + + auto v_result = make_intrusive(GetType()); + + for ( unsigned int i = 0; i < v_op1->Size(); ++i ) { + auto v1_i = v_op1->ValAt(i); + auto v2_i = v_op2->ValAt(i); + if ( v1_i && v2_i ) + v_result->Assign(i, Fold(v_op1->ValAt(i).get(), v_op2->ValAt(i).get())); + else + v_result->Assign(i, nullptr); + } + + return v_result; + } + + if ( IsVector(GetType()->Tag()) && (is_vec1 || is_vec2) ) { // fold vector against scalar + VectorVal* vv = (is_vec1 ? v1 : v2)->AsVectorVal(); + auto v_result = make_intrusive(GetType()); + + for ( unsigned int i = 0; i < vv->Size(); ++i ) { + auto vv_i = vv->ValAt(i); + if ( vv_i ) + v_result->Assign(i, is_vec1 ? Fold(vv_i.get(), v2.get()) : Fold(v1.get(), vv_i.get())); + else + v_result->Assign(i, nullptr); + } + + return v_result; + } + + // scalar op scalar + return Fold(v1.get(), v2.get()); +} + +bool BinaryExpr::IsPure() const { return op1->IsPure() && op2->IsPure(); } + +TraversalCode BinaryExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = op1->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = op2->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +void BinaryExpr::ExprDescribe(ODesc* d) const { + op1->Describe(d); + + d->SP(); + if ( d->IsReadable() ) + d->AddSP(expr_name(Tag())); + + op2->Describe(d); +} + +ValPtr BinaryExpr::Fold(Val* v1, Val* v2) const { + auto& t1 = v1->GetType(); + InternalTypeTag it = t1->InternalType(); + + if ( it == TYPE_INTERNAL_STRING ) + return StringFold(v1, v2); + + if ( t1->Tag() == TYPE_PATTERN ) + return PatternFold(v1, v2); + + if ( t1->IsSet() ) + return SetFold(v1, v2); + + if ( t1->IsTable() ) + return TableFold(v1, v2); + + if ( t1->Tag() == TYPE_VECTOR ) { + // We only get here when using a matching vector on the RHS. + if ( ! v2->AsVectorVal()->AddTo(v1, false) ) + Error("incompatible vector element assignment", v2); + return {NewRef{}, v1}; + } + + if ( it == TYPE_INTERNAL_ADDR ) + return AddrFold(v1, v2); + + if ( it == TYPE_INTERNAL_SUBNET ) + return SubNetFold(v1, v2); + + zeek_int_t i1 = 0, i2 = 0, i3 = 0; + zeek_uint_t u1 = 0, u2 = 0, u3 = 0; + double d1 = 0.0, d2 = 0.0, d3 = 0.0; + bool is_integral = false; + bool is_unsigned = false; + + if ( it == TYPE_INTERNAL_INT ) { + i1 = v1->InternalInt(); + i2 = v2->InternalInt(); + is_integral = true; + } + else if ( it == TYPE_INTERNAL_UNSIGNED ) { + u1 = v1->InternalUnsigned(); + u2 = v2->InternalUnsigned(); + is_unsigned = true; + } + else if ( it == TYPE_INTERNAL_DOUBLE ) { + d1 = v1->InternalDouble(); + d2 = v2->InternalDouble(); + } + else + RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); + + switch ( tag ) { +#define DO_INT_FOLD(op) \ + if ( is_integral ) \ + i3 = i1 op i2; \ + else if ( is_unsigned ) \ + u3 = u1 op u2; \ + else \ + RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); + +#define DO_UINT_FOLD(op) \ + if ( is_unsigned ) \ + u3 = u1 op u2; \ + else \ + RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); + +#define DO_FOLD(op) \ + if ( is_integral ) \ + i3 = i1 op i2; \ + else if ( is_unsigned ) \ + u3 = u1 op u2; \ + else \ + d3 = d1 op d2; + +#define DO_INT_VAL_FOLD(op) \ + if ( is_integral ) \ + i3 = i1 op i2; \ + else if ( is_unsigned ) \ + i3 = u1 op u2; \ + else \ + i3 = d1 op d2; + + case EXPR_ADD: + case EXPR_ADD_TO: DO_FOLD(+); break; + case EXPR_SUB: + case EXPR_REMOVE_FROM: + DO_FOLD(-); + // When subtracting and the result is larger than the left + // operand we mostly likely underflowed and log a warning. + if ( is_unsigned && u3 > u1 ) + reporter->ExprRuntimeWarning(this, "count underflow"); + break; + case EXPR_TIMES: DO_FOLD(*); break; + case EXPR_DIVIDE: { + if ( is_integral ) { + if ( i2 == 0 ) + RuntimeError("division by zero"); + + i3 = i1 / i2; + } + + else if ( is_unsigned ) { + if ( u2 == 0 ) + RuntimeError("division by zero"); + + u3 = u1 / u2; + } + else { + if ( d2 == 0 ) + RuntimeError("division by zero"); + + d3 = d1 / d2; + } + } break; + + case EXPR_MOD: { + if ( is_integral ) { + if ( i2 == 0 ) + RuntimeError("modulo by zero"); + + i3 = i1 % i2; + } + + else if ( is_unsigned ) { + if ( u2 == 0 ) + RuntimeError("modulo by zero"); + + u3 = u1 % u2; + } + + else + RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); + } + + break; + + case EXPR_AND: DO_UINT_FOLD(&); break; + case EXPR_OR: DO_UINT_FOLD(|); break; + case EXPR_XOR: DO_UINT_FOLD(^); break; + case EXPR_LSHIFT: { + if ( is_integral ) { + if ( i1 < 0 ) + RuntimeError("left shifting a negative number is undefined"); + + i3 = i1 << static_cast(i2); + } + + else if ( is_unsigned ) + u3 = u1 << u2; + + else + RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); + break; + } + case EXPR_RSHIFT: { + if ( is_integral ) + i3 = i1 >> static_cast(i2); + + else if ( is_unsigned ) + u3 = u1 >> u2; + + else + RuntimeErrorWithCallStack("bad type in BinaryExpr::Fold"); + break; + } + + case EXPR_AND_AND: DO_INT_FOLD(&&); break; + case EXPR_OR_OR: DO_INT_FOLD(||); break; + + case EXPR_LT: DO_INT_VAL_FOLD(<); break; + case EXPR_LE: DO_INT_VAL_FOLD(<=); break; + case EXPR_EQ: DO_INT_VAL_FOLD(==); break; + case EXPR_NE: DO_INT_VAL_FOLD(!=); break; + case EXPR_GE: DO_INT_VAL_FOLD(>=); break; + case EXPR_GT: DO_INT_VAL_FOLD(>); break; + + default: BadTag("BinaryExpr::Fold", expr_name(tag)); + } + + const auto& ret_type = IsVector(GetType()->Tag()) ? GetType()->Yield() : GetType(); + + if ( ret_type->Tag() == TYPE_INTERVAL ) + return make_intrusive(d3); + else if ( ret_type->Tag() == TYPE_TIME ) + return make_intrusive(d3); + else if ( ret_type->Tag() == TYPE_DOUBLE ) + return make_intrusive(d3); + else if ( ret_type->InternalType() == TYPE_INTERNAL_UNSIGNED ) + return val_mgr->Count(u3); + else if ( ret_type->Tag() == TYPE_BOOL ) + return val_mgr->Bool(i3); + else + return val_mgr->Int(i3); +} + +ValPtr BinaryExpr::StringFold(Val* v1, Val* v2) const { + const String* s1 = v1->AsString(); + const String* s2 = v2->AsString(); + int result = 0; + + switch ( tag ) { #undef DO_FOLD -#define DO_FOLD(sense) \ - { \ - result = Bstr_cmp(s1, s2) sense 0; \ - break; \ - } - - case EXPR_LT: - DO_FOLD(<) - case EXPR_LE: - DO_FOLD(<=) - case EXPR_EQ: - DO_FOLD(==) - case EXPR_NE: - DO_FOLD(!=) - case EXPR_GE: - DO_FOLD(>=) - case EXPR_GT: - DO_FOLD(>) - - case EXPR_ADD: - case EXPR_ADD_TO: - { - std::vector strings; - strings.push_back(s1); - strings.push_back(s2); - - return make_intrusive(concatenate(strings)); - } - - default: - BadTag("BinaryExpr::StringFold", expr_name(tag)); - } - - return val_mgr->Bool(result); - } - -ValPtr BinaryExpr::PatternFold(Val* v1, Val* v2) const - { - const RE_Matcher* re1 = v1->AsPattern(); - const RE_Matcher* re2 = v2->AsPattern(); - - if ( tag != EXPR_AND && tag != EXPR_OR ) - BadTag("BinaryExpr::PatternFold"); - - RE_Matcher* res = tag == EXPR_AND ? RE_Matcher_conjunction(re1, re2) - : RE_Matcher_disjunction(re1, re2); - - return make_intrusive(res); - } - -ValPtr BinaryExpr::SetFold(Val* v1, Val* v2) const - { - TableVal* tv1 = v1->AsTableVal(); - TableVal* tv2 = v2->AsTableVal(); - bool res = false; - - switch ( tag ) - { - case EXPR_AND: - return tv1->Intersection(*tv2); - - case EXPR_OR: - { - auto rval = v1->Clone(); - - if ( ! tv2->AddTo(rval.get(), false, false) ) - reporter->InternalError("set union failed to type check"); - - return rval; - } - - case EXPR_SUB: - { - auto rval = v1->Clone(); - - if ( ! tv2->RemoveFrom(rval.get()) ) - reporter->InternalError("set difference failed to type check"); - - return rval; - } - - case EXPR_EQ: - res = tv1->EqualTo(*tv2); - break; - - case EXPR_NE: - res = ! tv1->EqualTo(*tv2); - break; - - case EXPR_LT: - res = tv1->IsSubsetOf(*tv2) && tv1->Size() < tv2->Size(); - break; - - case EXPR_LE: - res = tv1->IsSubsetOf(*tv2); - break; - - case EXPR_GE: - case EXPR_GT: - // These shouldn't happen due to canonicalization. - reporter->InternalError("confusion over canonicalization in set comparison"); - break; - - case EXPR_ADD_TO: - // Avoid doing the AddTo operation if tv2 is empty, - // because then it might not type-check for trivial - // reasons. - if ( tv2->Size() > 0 ) - tv2->AddTo(tv1, false); - return {NewRef{}, tv1}; - - case EXPR_REMOVE_FROM: - if ( tv2->Size() > 0 ) - tv2->RemoveFrom(tv1); - return {NewRef{}, tv1}; - - default: - BadTag("BinaryExpr::SetFold", expr_name(tag)); - return nullptr; - } - - return val_mgr->Bool(res); - } - -ValPtr BinaryExpr::TableFold(Val* v1, Val* v2) const - { - TableVal* tv1 = v1->AsTableVal(); - TableVal* tv2 = v2->AsTableVal(); - - switch ( tag ) - { - case EXPR_ADD_TO: - // Avoid doing the AddTo operation if tv2 is empty, - // because then it might not type-check for trivial - // reasons. - if ( tv2->Size() > 0 ) - tv2->AddTo(tv1, false); - return {NewRef{}, tv1}; - - case EXPR_REMOVE_FROM: - if ( tv2->Size() > 0 ) - tv2->RemoveFrom(tv1); - return {NewRef{}, tv1}; - - default: - BadTag("BinaryExpr::TableFold", expr_name(tag)); - } - - return nullptr; - } - -ValPtr BinaryExpr::AddrFold(Val* v1, Val* v2) const - { - IPAddr a1 = v1->AsAddr(); - IPAddr a2 = v2->AsAddr(); - bool result = false; - - switch ( tag ) - { - - case EXPR_LT: - result = a1 < a2; - break; - case EXPR_LE: - result = a1 < a2 || a1 == a2; - break; - case EXPR_EQ: - result = a1 == a2; - break; - case EXPR_NE: - result = a1 != a2; - break; - case EXPR_GE: - result = ! (a1 < a2); - break; - case EXPR_GT: - result = (! (a1 < a2)) && (a1 != a2); - break; - - default: - BadTag("BinaryExpr::AddrFold", expr_name(tag)); - } - - return val_mgr->Bool(result); - } - -ValPtr BinaryExpr::SubNetFold(Val* v1, Val* v2) const - { - const IPPrefix& n1 = v1->AsSubNet(); - const IPPrefix& n2 = v2->AsSubNet(); - - bool result = n1 == n2; - - if ( tag == EXPR_NE ) - result = ! result; - - return val_mgr->Bool(result); - } - -void BinaryExpr::SwapOps() - { - // We could check here whether the operator is commutative. - using std::swap; - swap(op1, op2); - } - -void BinaryExpr::PromoteOps(TypeTag t) - { - TypeTag bt1 = op1->GetType()->Tag(); - TypeTag bt2 = op2->GetType()->Tag(); - - bool is_vec1 = IsVector(bt1); - bool is_vec2 = IsVector(bt2); - - if ( is_vec1 ) - bt1 = op1->GetType()->AsVectorType()->Yield()->Tag(); - if ( is_vec2 ) - bt2 = op2->GetType()->AsVectorType()->Yield()->Tag(); - - if ( bt1 != t ) - op1 = make_intrusive(op1, t); - if ( bt2 != t ) - op2 = make_intrusive(op2, t); - } - -void BinaryExpr::PromoteType(TypeTag t, bool is_vector) - { - PromoteOps(t); - - if ( is_vector ) - SetType(make_intrusive(base_type(t))); - else - SetType(base_type(t)); - } - -void BinaryExpr::PromoteForInterval(ExprPtr& op) - { - if ( is_vector(op1) || is_vector(op2) ) - SetType(make_intrusive(base_type(TYPE_INTERVAL))); - else - SetType(base_type(TYPE_INTERVAL)); - - if ( op->GetType()->Tag() != TYPE_DOUBLE ) - op = make_intrusive(op, TYPE_DOUBLE); - } - -bool BinaryExpr::CheckForRHSList() - { - if ( op2->Tag() != EXPR_LIST ) - return false; - - auto lhs_t = op1->GetType(); - auto rhs = cast_intrusive(op2); - auto& rhs_exprs = rhs->Exprs(); - - if ( lhs_t->Tag() == TYPE_TABLE ) - { - if ( lhs_t->IsSet() && rhs_exprs.size() >= 1 && same_type(lhs_t, rhs_exprs[0]->GetType()) ) - { - // This is potentially the idiom of "set1 += { set2 }" - // or "set1 += { set2, set3, set4 }". - op2 = {NewRef{}, rhs_exprs[0]}; - - for ( auto i = 1U; i < rhs_exprs.size(); ++i ) - { - ExprPtr re_i = {NewRef{}, rhs_exprs[i]}; - op2 = make_intrusive(EXPR_OR, op2, re_i); - } - - SetType(op1->GetType()); - - return true; - } - - if ( lhs_t->IsTable() && rhs_exprs.size() == 1 && - same_type(lhs_t, rhs_exprs[0]->GetType()) ) - { - // This is the idiom of "table1 += { table2 }" (or -=). - // Unlike for sets we don't allow more than one table - // in the RHS list because table "union" isn't - // well-defined. - op2 = {NewRef{}, rhs_exprs[0]}; - SetType(op1->GetType()); - - return true; - } - - if ( lhs_t->IsTable() ) - op2 = make_intrusive(rhs, nullptr, lhs_t); - else - op2 = make_intrusive(rhs, nullptr, lhs_t); - } - - else if ( lhs_t->Tag() == TYPE_VECTOR ) - { - if ( tag == EXPR_REMOVE_FROM ) - { - ExprError("constructor list not allowed for -= operations on vectors"); - return false; - } - - op2 = make_intrusive(rhs, lhs_t); - } - - else - { - ExprError("invalid constructor list on RHS of assignment"); - return false; - } - - if ( op2->IsError() ) - { - // Message should have already been generated, but propagate. - SetError(); - return false; - } - - // Don't bother type-checking for the degenerate case of the RHS - // being empty, since it won't actually matter. - if ( ! rhs_exprs.empty() && ! same_type(op1->GetType(), op2->GetType()) ) - { - ExprError("type clash for constructor list on RHS of assignment"); - return false; - } - - SetType(op1->GetType()); - - return true; - } - -CloneExpr::CloneExpr(ExprPtr arg_op) : UnaryExpr(EXPR_CLONE, std::move(arg_op)) - { - if ( IsError() ) - return; - - SetType(op->GetType()); - } - -ValPtr CloneExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; - - if ( auto v = op->Eval(f) ) - return Fold(v.get()); - - return nullptr; - } - -ValPtr CloneExpr::Fold(Val* v) const - { - return v->Clone(); - } - -IncrExpr::IncrExpr(ExprTag arg_tag, ExprPtr arg_op) : UnaryExpr(arg_tag, arg_op->MakeLvalue()) - { - if ( IsError() ) - return; - - const auto& t = op->GetType(); - if ( ! IsIntegral(t->Tag()) ) - ExprError("requires an integral operand"); - else - SetType(t); - } - -ValPtr IncrExpr::DoSingleEval(Frame* f, Val* v) const - { - zeek_int_t k = v->CoerceToInt(); - - if ( Tag() == EXPR_INCR ) - ++k; - else - { - --k; - - if ( k < 0 && v->GetType()->InternalType() == TYPE_INTERNAL_UNSIGNED ) - reporter->ExprRuntimeWarning(this, "count underflow"); - } - - const auto& ret_type = IsVector(GetType()->Tag()) ? GetType()->Yield() : GetType(); - - if ( ret_type->Tag() == TYPE_INT ) - return val_mgr->Int(k); - else - return val_mgr->Count(k); - } - -ValPtr IncrExpr::Eval(Frame* f) const - { - auto v = op->Eval(f); - - if ( ! v ) - return nullptr; - - auto new_v = DoSingleEval(f, v.get()); - op->Assign(f, new_v); - return new_v; - } - -ComplementExpr::ComplementExpr(ExprPtr arg_op) : UnaryExpr(EXPR_COMPLEMENT, std::move(arg_op)) - { - if ( IsError() ) - return; - - const auto& t = op->GetType(); - TypeTag bt = t->Tag(); - - if ( bt != TYPE_COUNT ) - ExprError("requires \"count\" operand"); - else - SetType(base_type(TYPE_COUNT)); - } - -ValPtr ComplementExpr::Fold(Val* v) const - { - return val_mgr->Count(~v->InternalUnsigned()); - } - -NotExpr::NotExpr(ExprPtr arg_op) : UnaryExpr(EXPR_NOT, std::move(arg_op)) - { - if ( IsError() ) - return; - - TypeTag bt = op->GetType()->Tag(); - - if ( ! IsIntegral(bt) && bt != TYPE_BOOL ) - ExprError("requires an integral or boolean operand"); - else - SetType(base_type(TYPE_BOOL)); - } - -ValPtr NotExpr::Fold(Val* v) const - { - return val_mgr->Bool(! v->InternalInt()); - } - -PosExpr::PosExpr(ExprPtr arg_op) : UnaryExpr(EXPR_POSITIVE, std::move(arg_op)) - { - if ( IsError() ) - return; - - const auto& t = IsVector(op->GetType()->Tag()) ? op->GetType()->Yield() : op->GetType(); - - TypeTag bt = t->Tag(); - TypePtr base_result_type; - - if ( IsIntegral(bt) ) - // Promote count and counter to int. - base_result_type = base_type(TYPE_INT); - else if ( bt == TYPE_INTERVAL || bt == TYPE_DOUBLE ) - base_result_type = t; - else - ExprError("requires an integral or double operand"); - - if ( is_vector(op) ) - SetType(make_intrusive(std::move(base_result_type))); - else - SetType(std::move(base_result_type)); - } - -ValPtr PosExpr::Fold(Val* v) const - { - TypeTag t = v->GetType()->Tag(); - - if ( t == TYPE_DOUBLE || t == TYPE_INTERVAL || t == TYPE_INT ) - return {NewRef{}, v}; - else - return val_mgr->Int(v->CoerceToInt()); - } - -NegExpr::NegExpr(ExprPtr arg_op) : UnaryExpr(EXPR_NEGATE, std::move(arg_op)) - { - if ( IsError() ) - return; - - const auto& t = IsVector(op->GetType()->Tag()) ? op->GetType()->Yield() : op->GetType(); - - TypeTag bt = t->Tag(); - TypePtr base_result_type; - - if ( IsIntegral(bt) ) - // Promote count and counter to int. - base_result_type = base_type(TYPE_INT); - else if ( bt == TYPE_INTERVAL || bt == TYPE_DOUBLE ) - base_result_type = t; - else - ExprError("requires an integral or double operand"); - - if ( is_vector(op) ) - SetType(make_intrusive(std::move(base_result_type))); - else - SetType(std::move(base_result_type)); - } - -ValPtr NegExpr::Fold(Val* v) const - { - if ( v->GetType()->Tag() == TYPE_DOUBLE ) - return make_intrusive(-v->InternalDouble()); - else if ( v->GetType()->Tag() == TYPE_INTERVAL ) - return make_intrusive(-v->InternalDouble()); - else - return val_mgr->Int(-v->CoerceToInt()); - } - -SizeExpr::SizeExpr(ExprPtr arg_op) : UnaryExpr(EXPR_SIZE, std::move(arg_op)) - { - if ( IsError() ) - return; - - auto& t = op->GetType(); - - if ( t->Tag() == TYPE_ANY ) - SetType(base_type(TYPE_ANY)); - else if ( t->Tag() == TYPE_FILE || t->Tag() == TYPE_SUBNET || - t->InternalType() == TYPE_INTERNAL_DOUBLE ) - SetType(base_type(TYPE_DOUBLE)); - else - SetType(base_type(TYPE_COUNT)); - } - -ValPtr SizeExpr::Eval(Frame* f) const - { - auto v = op->Eval(f); - - if ( ! v ) - return nullptr; - - return Fold(v.get()); - } - -ValPtr SizeExpr::Fold(Val* v) const - { - return v->SizeVal(); - } +#define DO_FOLD(sense) \ + { \ + result = Bstr_cmp(s1, s2) sense 0; \ + break; \ + } + + case EXPR_LT: DO_FOLD(<) + case EXPR_LE: DO_FOLD(<=) + case EXPR_EQ: DO_FOLD(==) + case EXPR_NE: DO_FOLD(!=) + case EXPR_GE: DO_FOLD(>=) + case EXPR_GT: DO_FOLD(>) + + case EXPR_ADD: + case EXPR_ADD_TO: { + std::vector strings; + strings.push_back(s1); + strings.push_back(s2); + + return make_intrusive(concatenate(strings)); + } + + default: BadTag("BinaryExpr::StringFold", expr_name(tag)); + } + + return val_mgr->Bool(result); +} + +ValPtr BinaryExpr::PatternFold(Val* v1, Val* v2) const { + const RE_Matcher* re1 = v1->AsPattern(); + const RE_Matcher* re2 = v2->AsPattern(); + + if ( tag != EXPR_AND && tag != EXPR_OR ) + BadTag("BinaryExpr::PatternFold"); + + RE_Matcher* res = tag == EXPR_AND ? RE_Matcher_conjunction(re1, re2) : RE_Matcher_disjunction(re1, re2); + + return make_intrusive(res); +} + +ValPtr BinaryExpr::SetFold(Val* v1, Val* v2) const { + TableVal* tv1 = v1->AsTableVal(); + TableVal* tv2 = v2->AsTableVal(); + bool res = false; + + switch ( tag ) { + case EXPR_AND: return tv1->Intersection(*tv2); + + case EXPR_OR: { + auto rval = v1->Clone(); + + if ( ! tv2->AddTo(rval.get(), false, false) ) + reporter->InternalError("set union failed to type check"); + + return rval; + } + + case EXPR_SUB: { + auto rval = v1->Clone(); + + if ( ! tv2->RemoveFrom(rval.get()) ) + reporter->InternalError("set difference failed to type check"); + + return rval; + } + + case EXPR_EQ: res = tv1->EqualTo(*tv2); break; + + case EXPR_NE: res = ! tv1->EqualTo(*tv2); break; + + case EXPR_LT: res = tv1->IsSubsetOf(*tv2) && tv1->Size() < tv2->Size(); break; + + case EXPR_LE: res = tv1->IsSubsetOf(*tv2); break; + + case EXPR_GE: + case EXPR_GT: + // These shouldn't happen due to canonicalization. + reporter->InternalError("confusion over canonicalization in set comparison"); + break; + + case EXPR_ADD_TO: + // Avoid doing the AddTo operation if tv2 is empty, + // because then it might not type-check for trivial + // reasons. + if ( tv2->Size() > 0 ) + tv2->AddTo(tv1, false); + return {NewRef{}, tv1}; + + case EXPR_REMOVE_FROM: + if ( tv2->Size() > 0 ) + tv2->RemoveFrom(tv1); + return {NewRef{}, tv1}; + + default: BadTag("BinaryExpr::SetFold", expr_name(tag)); return nullptr; + } + + return val_mgr->Bool(res); +} + +ValPtr BinaryExpr::TableFold(Val* v1, Val* v2) const { + TableVal* tv1 = v1->AsTableVal(); + TableVal* tv2 = v2->AsTableVal(); + + switch ( tag ) { + case EXPR_ADD_TO: + // Avoid doing the AddTo operation if tv2 is empty, + // because then it might not type-check for trivial + // reasons. + if ( tv2->Size() > 0 ) + tv2->AddTo(tv1, false); + return {NewRef{}, tv1}; + + case EXPR_REMOVE_FROM: + if ( tv2->Size() > 0 ) + tv2->RemoveFrom(tv1); + return {NewRef{}, tv1}; + + default: BadTag("BinaryExpr::TableFold", expr_name(tag)); + } + + return nullptr; +} + +ValPtr BinaryExpr::AddrFold(Val* v1, Val* v2) const { + IPAddr a1 = v1->AsAddr(); + IPAddr a2 = v2->AsAddr(); + bool result = false; + + switch ( tag ) { + case EXPR_LT: result = a1 < a2; break; + case EXPR_LE: result = a1 < a2 || a1 == a2; break; + case EXPR_EQ: result = a1 == a2; break; + case EXPR_NE: result = a1 != a2; break; + case EXPR_GE: result = ! (a1 < a2); break; + case EXPR_GT: result = (! (a1 < a2)) && (a1 != a2); break; + + default: BadTag("BinaryExpr::AddrFold", expr_name(tag)); + } + + return val_mgr->Bool(result); +} + +ValPtr BinaryExpr::SubNetFold(Val* v1, Val* v2) const { + const IPPrefix& n1 = v1->AsSubNet(); + const IPPrefix& n2 = v2->AsSubNet(); + + bool result = n1 == n2; + + if ( tag == EXPR_NE ) + result = ! result; + + return val_mgr->Bool(result); +} + +void BinaryExpr::SwapOps() { + // We could check here whether the operator is commutative. + using std::swap; + swap(op1, op2); +} + +void BinaryExpr::PromoteOps(TypeTag t) { + TypeTag bt1 = op1->GetType()->Tag(); + TypeTag bt2 = op2->GetType()->Tag(); + + bool is_vec1 = IsVector(bt1); + bool is_vec2 = IsVector(bt2); + + if ( is_vec1 ) + bt1 = op1->GetType()->AsVectorType()->Yield()->Tag(); + if ( is_vec2 ) + bt2 = op2->GetType()->AsVectorType()->Yield()->Tag(); + + if ( bt1 != t ) + op1 = make_intrusive(op1, t); + if ( bt2 != t ) + op2 = make_intrusive(op2, t); +} + +void BinaryExpr::PromoteType(TypeTag t, bool is_vector) { + PromoteOps(t); + + if ( is_vector ) + SetType(make_intrusive(base_type(t))); + else + SetType(base_type(t)); +} + +void BinaryExpr::PromoteForInterval(ExprPtr& op) { + if ( is_vector(op1) || is_vector(op2) ) + SetType(make_intrusive(base_type(TYPE_INTERVAL))); + else + SetType(base_type(TYPE_INTERVAL)); + + if ( op->GetType()->Tag() != TYPE_DOUBLE ) + op = make_intrusive(op, TYPE_DOUBLE); +} + +bool BinaryExpr::CheckForRHSList() { + if ( op2->Tag() != EXPR_LIST ) + return false; + + auto lhs_t = op1->GetType(); + auto rhs = cast_intrusive(op2); + auto& rhs_exprs = rhs->Exprs(); + + if ( lhs_t->Tag() == TYPE_TABLE ) { + if ( lhs_t->IsSet() && rhs_exprs.size() >= 1 && same_type(lhs_t, rhs_exprs[0]->GetType()) ) { + // This is potentially the idiom of "set1 += { set2 }" + // or "set1 += { set2, set3, set4 }". + op2 = {NewRef{}, rhs_exprs[0]}; + + for ( auto i = 1U; i < rhs_exprs.size(); ++i ) { + ExprPtr re_i = {NewRef{}, rhs_exprs[i]}; + op2 = make_intrusive(EXPR_OR, op2, re_i); + } + + SetType(op1->GetType()); + + return true; + } + + if ( lhs_t->IsTable() && rhs_exprs.size() == 1 && same_type(lhs_t, rhs_exprs[0]->GetType()) ) { + // This is the idiom of "table1 += { table2 }" (or -=). + // Unlike for sets we don't allow more than one table + // in the RHS list because table "union" isn't + // well-defined. + op2 = {NewRef{}, rhs_exprs[0]}; + SetType(op1->GetType()); + + return true; + } + + if ( lhs_t->IsTable() ) + op2 = make_intrusive(rhs, nullptr, lhs_t); + else + op2 = make_intrusive(rhs, nullptr, lhs_t); + } + + else if ( lhs_t->Tag() == TYPE_VECTOR ) { + if ( tag == EXPR_REMOVE_FROM ) { + ExprError("constructor list not allowed for -= operations on vectors"); + return false; + } + + op2 = make_intrusive(rhs, lhs_t); + } + + else { + ExprError("invalid constructor list on RHS of assignment"); + return false; + } + + if ( op2->IsError() ) { + // Message should have already been generated, but propagate. + SetError(); + return false; + } + + // Don't bother type-checking for the degenerate case of the RHS + // being empty, since it won't actually matter. + if ( ! rhs_exprs.empty() && ! same_type(op1->GetType(), op2->GetType()) ) { + ExprError("type clash for constructor list on RHS of assignment"); + return false; + } + + SetType(op1->GetType()); + + return true; +} + +CloneExpr::CloneExpr(ExprPtr arg_op) : UnaryExpr(EXPR_CLONE, std::move(arg_op)) { + if ( IsError() ) + return; + + SetType(op->GetType()); +} + +ValPtr CloneExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; + + if ( auto v = op->Eval(f) ) + return Fold(v.get()); + + return nullptr; +} + +ValPtr CloneExpr::Fold(Val* v) const { return v->Clone(); } + +IncrExpr::IncrExpr(ExprTag arg_tag, ExprPtr arg_op) : UnaryExpr(arg_tag, arg_op->MakeLvalue()) { + if ( IsError() ) + return; + + const auto& t = op->GetType(); + if ( ! IsIntegral(t->Tag()) ) + ExprError("requires an integral operand"); + else + SetType(t); +} + +ValPtr IncrExpr::DoSingleEval(Frame* f, Val* v) const { + zeek_int_t k = v->CoerceToInt(); + + if ( Tag() == EXPR_INCR ) + ++k; + else { + --k; + + if ( k < 0 && v->GetType()->InternalType() == TYPE_INTERNAL_UNSIGNED ) + reporter->ExprRuntimeWarning(this, "count underflow"); + } + + const auto& ret_type = IsVector(GetType()->Tag()) ? GetType()->Yield() : GetType(); + + if ( ret_type->Tag() == TYPE_INT ) + return val_mgr->Int(k); + else + return val_mgr->Count(k); +} + +ValPtr IncrExpr::Eval(Frame* f) const { + auto v = op->Eval(f); + + if ( ! v ) + return nullptr; + + auto new_v = DoSingleEval(f, v.get()); + op->Assign(f, new_v); + return new_v; +} + +ComplementExpr::ComplementExpr(ExprPtr arg_op) : UnaryExpr(EXPR_COMPLEMENT, std::move(arg_op)) { + if ( IsError() ) + return; + + const auto& t = op->GetType(); + TypeTag bt = t->Tag(); + + if ( bt != TYPE_COUNT ) + ExprError("requires \"count\" operand"); + else + SetType(base_type(TYPE_COUNT)); +} + +ValPtr ComplementExpr::Fold(Val* v) const { return val_mgr->Count(~v->InternalUnsigned()); } + +NotExpr::NotExpr(ExprPtr arg_op) : UnaryExpr(EXPR_NOT, std::move(arg_op)) { + if ( IsError() ) + return; + + TypeTag bt = op->GetType()->Tag(); + + if ( ! IsIntegral(bt) && bt != TYPE_BOOL ) + ExprError("requires an integral or boolean operand"); + else + SetType(base_type(TYPE_BOOL)); +} + +ValPtr NotExpr::Fold(Val* v) const { return val_mgr->Bool(! v->InternalInt()); } + +PosExpr::PosExpr(ExprPtr arg_op) : UnaryExpr(EXPR_POSITIVE, std::move(arg_op)) { + if ( IsError() ) + return; + + const auto& t = IsVector(op->GetType()->Tag()) ? op->GetType()->Yield() : op->GetType(); + + TypeTag bt = t->Tag(); + TypePtr base_result_type; + + if ( IsIntegral(bt) ) + // Promote count and counter to int. + base_result_type = base_type(TYPE_INT); + else if ( bt == TYPE_INTERVAL || bt == TYPE_DOUBLE ) + base_result_type = t; + else + ExprError("requires an integral or double operand"); + + if ( is_vector(op) ) + SetType(make_intrusive(std::move(base_result_type))); + else + SetType(std::move(base_result_type)); +} + +ValPtr PosExpr::Fold(Val* v) const { + TypeTag t = v->GetType()->Tag(); + + if ( t == TYPE_DOUBLE || t == TYPE_INTERVAL || t == TYPE_INT ) + return {NewRef{}, v}; + else + return val_mgr->Int(v->CoerceToInt()); +} + +NegExpr::NegExpr(ExprPtr arg_op) : UnaryExpr(EXPR_NEGATE, std::move(arg_op)) { + if ( IsError() ) + return; + + const auto& t = IsVector(op->GetType()->Tag()) ? op->GetType()->Yield() : op->GetType(); + + TypeTag bt = t->Tag(); + TypePtr base_result_type; + + if ( IsIntegral(bt) ) + // Promote count and counter to int. + base_result_type = base_type(TYPE_INT); + else if ( bt == TYPE_INTERVAL || bt == TYPE_DOUBLE ) + base_result_type = t; + else + ExprError("requires an integral or double operand"); + + if ( is_vector(op) ) + SetType(make_intrusive(std::move(base_result_type))); + else + SetType(std::move(base_result_type)); +} + +ValPtr NegExpr::Fold(Val* v) const { + if ( v->GetType()->Tag() == TYPE_DOUBLE ) + return make_intrusive(-v->InternalDouble()); + else if ( v->GetType()->Tag() == TYPE_INTERVAL ) + return make_intrusive(-v->InternalDouble()); + else + return val_mgr->Int(-v->CoerceToInt()); +} + +SizeExpr::SizeExpr(ExprPtr arg_op) : UnaryExpr(EXPR_SIZE, std::move(arg_op)) { + if ( IsError() ) + return; + + auto& t = op->GetType(); + + if ( t->Tag() == TYPE_ANY ) + SetType(base_type(TYPE_ANY)); + else if ( t->Tag() == TYPE_FILE || t->Tag() == TYPE_SUBNET || t->InternalType() == TYPE_INTERNAL_DOUBLE ) + SetType(base_type(TYPE_DOUBLE)); + else + SetType(base_type(TYPE_COUNT)); +} + +ValPtr SizeExpr::Eval(Frame* f) const { + auto v = op->Eval(f); + + if ( ! v ) + return nullptr; + + return Fold(v.get()); +} + +ValPtr SizeExpr::Fold(Val* v) const { return v->SizeVal(); } // Fill op1 and op2 type tags into bt1 and bt2. // // If both operands are vectors, use their yield type tag. If // either, but not both operands, is a vector, cause an expression // error and return false. -static bool get_types_from_scalars_or_vectors(Expr* e, TypeTag& bt1, TypeTag& bt2) - { - bt1 = e->GetOp1()->GetType()->Tag(); - bt2 = e->GetOp2()->GetType()->Tag(); +static bool get_types_from_scalars_or_vectors(Expr* e, TypeTag& bt1, TypeTag& bt2) { + bt1 = e->GetOp1()->GetType()->Tag(); + bt2 = e->GetOp2()->GetType()->Tag(); - if ( IsVector(bt1) && IsVector(bt2) ) - { - bt1 = e->GetOp1()->GetType()->AsVectorType()->Yield()->Tag(); - bt2 = e->GetOp2()->GetType()->AsVectorType()->Yield()->Tag(); - } - else if ( IsVector(bt1) || IsVector(bt2) ) - { - e->Error("cannot mix vector and scalar operands"); - e->SetError(); - return false; - } + if ( IsVector(bt1) && IsVector(bt2) ) { + bt1 = e->GetOp1()->GetType()->AsVectorType()->Yield()->Tag(); + bt2 = e->GetOp2()->GetType()->AsVectorType()->Yield()->Tag(); + } + else if ( IsVector(bt1) || IsVector(bt2) ) { + e->Error("cannot mix vector and scalar operands"); + e->SetError(); + return false; + } - return true; - } + return true; +} -AddExpr::AddExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_ADD, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; +AddExpr::AddExpr(ExprPtr arg_op1, ExprPtr arg_op2) : BinaryExpr(EXPR_ADD, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - TypePtr base_result_type; + TypePtr base_result_type; - if ( bt2 == TYPE_INTERVAL && (bt1 == TYPE_TIME || bt1 == TYPE_INTERVAL) ) - base_result_type = base_type(bt1); - else if ( bt2 == TYPE_TIME && bt1 == TYPE_INTERVAL ) - base_result_type = base_type(bt2); - else if ( BothArithmetic(bt1, bt2) ) - PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); - else if ( BothString(bt1, bt2) ) - base_result_type = base_type(bt1); - else - ExprError("requires arithmetic operands"); + if ( bt2 == TYPE_INTERVAL && (bt1 == TYPE_TIME || bt1 == TYPE_INTERVAL) ) + base_result_type = base_type(bt1); + else if ( bt2 == TYPE_TIME && bt1 == TYPE_INTERVAL ) + base_result_type = base_type(bt2); + else if ( BothArithmetic(bt1, bt2) ) + PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else if ( BothString(bt1, bt2) ) + base_result_type = base_type(bt1); + else + ExprError("requires arithmetic operands"); - if ( base_result_type ) - { - if ( is_vector(op1) ) - SetType(make_intrusive(std::move(base_result_type))); - else - SetType(std::move(base_result_type)); - } - } + if ( base_result_type ) { + if ( is_vector(op1) ) + SetType(make_intrusive(std::move(base_result_type))); + else + SetType(std::move(base_result_type)); + } +} -void AddExpr::Canonicalize() - { - if ( expr_greater(op2.get(), op1.get()) || - (op1->GetType()->Tag() == TYPE_INTERVAL && op2->GetType()->Tag() == TYPE_TIME) || - (op2->IsConst() && ! is_vector(op2->ExprVal()) && ! op1->IsConst()) ) - SwapOps(); - } +void AddExpr::Canonicalize() { + if ( expr_greater(op2.get(), op1.get()) || + (op1->GetType()->Tag() == TYPE_INTERVAL && op2->GetType()->Tag() == TYPE_TIME) || + (op2->IsConst() && ! is_vector(op2->ExprVal()) && ! op1->IsConst()) ) + SwapOps(); +} AddToExpr::AddToExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_ADD_TO, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(EXPR_ADD_TO, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - auto& t1 = op1->GetType(); - auto& t2 = op2->GetType(); - TypeTag bt1 = t1->Tag(); - TypeTag bt2 = t2->Tag(); + auto& t1 = op1->GetType(); + auto& t2 = op2->GetType(); + TypeTag bt1 = t1->Tag(); + TypeTag bt2 = t2->Tag(); - if ( bt1 != TYPE_TABLE && bt1 != TYPE_VECTOR && bt1 != TYPE_PATTERN ) - op1 = op1->MakeLvalue(); + if ( bt1 != TYPE_TABLE && bt1 != TYPE_VECTOR && bt1 != TYPE_PATTERN ) + op1 = op1->MakeLvalue(); - if ( BothArithmetic(bt1, bt2) ) - PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); - else if ( BothString(bt1, bt2) || BothInterval(bt1, bt2) ) - SetType(base_type(bt1)); + if ( BothArithmetic(bt1, bt2) ) + PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else if ( BothString(bt1, bt2) || BothInterval(bt1, bt2) ) + SetType(base_type(bt1)); - else if ( bt2 == TYPE_LIST ) - (void)CheckForRHSList(); + else if ( bt2 == TYPE_LIST ) + (void)CheckForRHSList(); - else if ( bt1 == TYPE_TABLE ) - { - if ( same_type(t1, t2) ) - SetType(t1); - else - ExprError("RHS type mismatch for table/set +="); - } + else if ( bt1 == TYPE_TABLE ) { + if ( same_type(t1, t2) ) + SetType(t1); + else + ExprError("RHS type mismatch for table/set +="); + } - else if ( bt1 == TYPE_PATTERN ) - { - if ( bt2 != TYPE_PATTERN ) - ExprError("pattern += op requires op to be a pattern"); - else - SetType(t1); - } + else if ( bt1 == TYPE_PATTERN ) { + if ( bt2 != TYPE_PATTERN ) + ExprError("pattern += op requires op to be a pattern"); + else + SetType(t1); + } - else if ( IsVector(bt1) ) - { - // We need the IsVector(bt2) check in the following because - // same_type() always treats "any" types as "same". - if ( IsVector(bt2) && same_type(t1, t2) ) - { - SetType(t1); - return; - } + else if ( IsVector(bt1) ) { + // We need the IsVector(bt2) check in the following because + // same_type() always treats "any" types as "same". + if ( IsVector(bt2) && same_type(t1, t2) ) { + SetType(t1); + return; + } - is_vector_elem_append = true; + is_vector_elem_append = true; - bt1 = t1->AsVectorType()->Yield()->Tag(); + bt1 = t1->AsVectorType()->Yield()->Tag(); - if ( IsArithmetic(bt1) ) - { - if ( IsArithmetic(bt2) ) - { - if ( bt2 != bt1 ) - op2 = make_intrusive(std::move(op2), bt1); + if ( IsArithmetic(bt1) ) { + if ( IsArithmetic(bt2) ) { + if ( bt2 != bt1 ) + op2 = make_intrusive(std::move(op2), bt1); - SetType(t1); - } + SetType(t1); + } - else - ExprError("appending non-arithmetic to arithmetic vector"); - } + else + ExprError("appending non-arithmetic to arithmetic vector"); + } - else if ( bt1 != bt2 && bt1 != TYPE_ANY ) - ExprError( - util::fmt("incompatible vector append: %s and %s", type_name(bt1), type_name(bt2))); + else if ( bt1 != bt2 && bt1 != TYPE_ANY ) + ExprError(util::fmt("incompatible vector append: %s and %s", type_name(bt1), type_name(bt2))); - else - SetType(t1); - } + else + SetType(t1); + } - else - ExprError("requires two arithmetic or two string operands"); - } + else + ExprError("requires two arithmetic or two string operands"); +} -ValPtr AddToExpr::Eval(Frame* f) const - { - auto v1 = op1->Eval(f); +ValPtr AddToExpr::Eval(Frame* f) const { + auto v1 = op1->Eval(f); - if ( ! v1 ) - return nullptr; + if ( ! v1 ) + return nullptr; - auto v2 = op2->Eval(f); + auto v2 = op2->Eval(f); - if ( ! v2 ) - return nullptr; + if ( ! v2 ) + return nullptr; - if ( is_vector_elem_append ) - { - VectorVal* vv = v1->AsVectorVal(); + if ( is_vector_elem_append ) { + VectorVal* vv = v1->AsVectorVal(); - if ( ! vv->Assign(vv->Size(), v2) ) - RuntimeError("type-checking failed in vector append"); + if ( ! vv->Assign(vv->Size(), v2) ) + RuntimeError("type-checking failed in vector append"); - return v1; - } + return v1; + } - if ( type->Tag() == TYPE_PATTERN ) - { - v2->AddTo(v1.get(), false); - return v1; - } + if ( type->Tag() == TYPE_PATTERN ) { + v2->AddTo(v1.get(), false); + return v1; + } - if ( auto result = Fold(v1.get(), v2.get()) ) - { - op1->Assign(f, result); - return result; - } - else - return nullptr; - } + if ( auto result = Fold(v1.get(), v2.get()) ) { + op1->Assign(f, result); + return result; + } + else + return nullptr; +} -SubExpr::SubExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_SUB, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; +SubExpr::SubExpr(ExprPtr arg_op1, ExprPtr arg_op2) : BinaryExpr(EXPR_SUB, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - const auto& t1 = op1->GetType(); - const auto& t2 = op2->GetType(); + const auto& t1 = op1->GetType(); + const auto& t2 = op2->GetType(); - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - TypePtr base_result_type; + TypePtr base_result_type; - if ( bt2 == TYPE_INTERVAL && (bt1 == TYPE_TIME || bt1 == TYPE_INTERVAL) ) - base_result_type = base_type(bt1); + if ( bt2 == TYPE_INTERVAL && (bt1 == TYPE_TIME || bt1 == TYPE_INTERVAL) ) + base_result_type = base_type(bt1); - else if ( bt1 == TYPE_TIME && bt2 == TYPE_TIME ) - SetType(base_type(TYPE_INTERVAL)); + else if ( bt1 == TYPE_TIME && bt2 == TYPE_TIME ) + SetType(base_type(TYPE_INTERVAL)); - else if ( t1->IsSet() && t2->IsSet() ) - { - if ( same_type(t1, t2) ) - SetType(op1->GetType()); - else - ExprError("incompatible \"set\" operands"); - } + else if ( t1->IsSet() && t2->IsSet() ) { + if ( same_type(t1, t2) ) + SetType(op1->GetType()); + else + ExprError("incompatible \"set\" operands"); + } - else if ( BothArithmetic(bt1, bt2) ) - PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else if ( BothArithmetic(bt1, bt2) ) + PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); - else - ExprError("requires arithmetic operands"); + else + ExprError("requires arithmetic operands"); - if ( base_result_type ) - { - if ( is_vector(op1) ) - SetType(make_intrusive(std::move(base_result_type))); - else - SetType(std::move(base_result_type)); - } - } + if ( base_result_type ) { + if ( is_vector(op1) ) + SetType(make_intrusive(std::move(base_result_type))); + else + SetType(std::move(base_result_type)); + } +} RemoveFromExpr::RemoveFromExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_REMOVE_FROM, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(EXPR_REMOVE_FROM, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - auto& t1 = op1->GetType(); - auto& t2 = op2->GetType(); - TypeTag bt1 = t1->Tag(); - TypeTag bt2 = t2->Tag(); + auto& t1 = op1->GetType(); + auto& t2 = op2->GetType(); + TypeTag bt1 = t1->Tag(); + TypeTag bt2 = t2->Tag(); - if ( bt1 != TYPE_TABLE ) - op1 = op1->MakeLvalue(); + if ( bt1 != TYPE_TABLE ) + op1 = op1->MakeLvalue(); - if ( BothArithmetic(bt1, bt2) ) - PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); - else if ( BothInterval(bt1, bt2) ) - SetType(base_type(bt1)); + if ( BothArithmetic(bt1, bt2) ) + PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else if ( BothInterval(bt1, bt2) ) + SetType(base_type(bt1)); - else if ( bt2 == TYPE_LIST ) - (void)CheckForRHSList(); + else if ( bt2 == TYPE_LIST ) + (void)CheckForRHSList(); - else if ( bt1 == TYPE_TABLE ) - { - if ( same_type(t1, t2) ) - SetType(t1); - else - ExprError("RHS type mismatch for table/set -="); - } + else if ( bt1 == TYPE_TABLE ) { + if ( same_type(t1, t2) ) + SetType(t1); + else + ExprError("RHS type mismatch for table/set -="); + } - else - ExprError("requires two arithmetic operands"); - } + else + ExprError("requires two arithmetic operands"); +} -ValPtr RemoveFromExpr::Eval(Frame* f) const - { - auto v1 = op1->Eval(f); +ValPtr RemoveFromExpr::Eval(Frame* f) const { + auto v1 = op1->Eval(f); - if ( ! v1 ) - return nullptr; + if ( ! v1 ) + return nullptr; - auto v2 = op2->Eval(f); + auto v2 = op2->Eval(f); - if ( ! v2 ) - return nullptr; + if ( ! v2 ) + return nullptr; - if ( auto result = Fold(v1.get(), v2.get()) ) - { - op1->Assign(f, result); - return result; - } - else - return nullptr; - } + if ( auto result = Fold(v1.get(), v2.get()) ) { + op1->Assign(f, result); + return result; + } + else + return nullptr; +} TimesExpr::TimesExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_TIMES, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(EXPR_TIMES, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - Canonicalize(); + Canonicalize(); - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - if ( bt1 == TYPE_INTERVAL || bt2 == TYPE_INTERVAL ) - { - if ( IsArithmetic(bt1) || IsArithmetic(bt2) ) - PromoteForInterval(IsArithmetic(bt1) ? op1 : op2); - else - ExprError("multiplication with interval requires arithmetic operand"); - } - else if ( BothArithmetic(bt1, bt2) ) - PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); - else - ExprError("requires arithmetic operands"); - } + if ( bt1 == TYPE_INTERVAL || bt2 == TYPE_INTERVAL ) { + if ( IsArithmetic(bt1) || IsArithmetic(bt2) ) + PromoteForInterval(IsArithmetic(bt1) ? op1 : op2); + else + ExprError("multiplication with interval requires arithmetic operand"); + } + else if ( BothArithmetic(bt1, bt2) ) + PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else + ExprError("requires arithmetic operands"); +} -void TimesExpr::Canonicalize() - { - if ( expr_greater(op2.get(), op1.get()) || op2->GetType()->Tag() == TYPE_INTERVAL || - (op2->IsConst() && ! is_vector(op2->ExprVal()) && ! op1->IsConst()) ) - SwapOps(); - } +void TimesExpr::Canonicalize() { + if ( expr_greater(op2.get(), op1.get()) || op2->GetType()->Tag() == TYPE_INTERVAL || + (op2->IsConst() && ! is_vector(op2->ExprVal()) && ! op1->IsConst()) ) + SwapOps(); +} DivideExpr::DivideExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_DIVIDE, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(EXPR_DIVIDE, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - if ( bt1 == TYPE_INTERVAL || bt2 == TYPE_INTERVAL ) - { - if ( IsArithmetic(bt1) || IsArithmetic(bt2) ) - PromoteForInterval(IsArithmetic(bt1) ? op1 : op2); - else if ( bt1 == TYPE_INTERVAL && bt2 == TYPE_INTERVAL ) - { - if ( is_vector(op1) ) - SetType(make_intrusive(base_type(TYPE_DOUBLE))); - else - SetType(base_type(TYPE_DOUBLE)); - } - else - ExprError("division of interval requires arithmetic operand"); - } + if ( bt1 == TYPE_INTERVAL || bt2 == TYPE_INTERVAL ) { + if ( IsArithmetic(bt1) || IsArithmetic(bt2) ) + PromoteForInterval(IsArithmetic(bt1) ? op1 : op2); + else if ( bt1 == TYPE_INTERVAL && bt2 == TYPE_INTERVAL ) { + if ( is_vector(op1) ) + SetType(make_intrusive(base_type(TYPE_DOUBLE))); + else + SetType(base_type(TYPE_DOUBLE)); + } + else + ExprError("division of interval requires arithmetic operand"); + } - else if ( BothArithmetic(bt1, bt2) ) - PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else if ( BothArithmetic(bt1, bt2) ) + PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); - else - ExprError("requires arithmetic operands"); - } + else + ExprError("requires arithmetic operands"); +} -MaskExpr::MaskExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_MASK, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; +MaskExpr::MaskExpr(ExprPtr arg_op1, ExprPtr arg_op2) : BinaryExpr(EXPR_MASK, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - if ( bt1 == TYPE_ADDR && ! is_vector(op2) && (bt2 == TYPE_COUNT || bt2 == TYPE_INT) ) - SetType(base_type(TYPE_SUBNET)); - else - ExprError("requires address LHS and count/int RHS"); - } + if ( bt1 == TYPE_ADDR && ! is_vector(op2) && (bt2 == TYPE_COUNT || bt2 == TYPE_INT) ) + SetType(base_type(TYPE_SUBNET)); + else + ExprError("requires address LHS and count/int RHS"); +} -ValPtr MaskExpr::AddrFold(Val* v1, Val* v2) const - { - uint32_t mask; +ValPtr MaskExpr::AddrFold(Val* v1, Val* v2) const { + uint32_t mask; - if ( v2->GetType()->Tag() == TYPE_COUNT ) - mask = static_cast(v2->InternalUnsigned()); - else - mask = static_cast(v2->InternalInt()); + if ( v2->GetType()->Tag() == TYPE_COUNT ) + mask = static_cast(v2->InternalUnsigned()); + else + mask = static_cast(v2->InternalInt()); - auto& a = v1->AsAddr(); + auto& a = v1->AsAddr(); - if ( a.GetFamily() == IPv4 ) - { - if ( mask > 32 ) - RuntimeError(util::fmt("bad IPv4 subnet prefix length: %" PRIu32, mask)); - } - else - { - if ( mask > 128 ) - RuntimeError(util::fmt("bad IPv6 subnet prefix length: %" PRIu32, mask)); - } + if ( a.GetFamily() == IPv4 ) { + if ( mask > 32 ) + RuntimeError(util::fmt("bad IPv4 subnet prefix length: %" PRIu32, mask)); + } + else { + if ( mask > 128 ) + RuntimeError(util::fmt("bad IPv6 subnet prefix length: %" PRIu32, mask)); + } - return make_intrusive(a, mask); - } + return make_intrusive(a, mask); +} -ModExpr::ModExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_MOD, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; +ModExpr::ModExpr(ExprPtr arg_op1, ExprPtr arg_op2) : BinaryExpr(EXPR_MOD, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - if ( BothIntegral(bt1, bt2) ) - PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); - else - ExprError("requires integral operands"); - } + if ( BothIntegral(bt1, bt2) ) + PromoteType(max_type(bt1, bt2), is_vector(op1) || is_vector(op2)); + else + ExprError("requires integral operands"); +} BoolExpr::BoolExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - if ( BothBool(bt1, bt2) ) - { - if ( is_vector(op1) ) - SetType(make_intrusive(base_type(TYPE_BOOL))); - else - SetType(base_type(TYPE_BOOL)); - } - else - ExprError("requires boolean operands"); - } + if ( BothBool(bt1, bt2) ) { + if ( is_vector(op1) ) + SetType(make_intrusive(base_type(TYPE_BOOL))); + else + SetType(base_type(TYPE_BOOL)); + } + else + ExprError("requires boolean operands"); +} -ValPtr BoolExpr::DoSingleEval(Frame* f, ValPtr v1, Expr* op2) const - { - if ( ! v1 ) - return nullptr; +ValPtr BoolExpr::DoSingleEval(Frame* f, ValPtr v1, Expr* op2) const { + if ( ! v1 ) + return nullptr; - if ( tag == EXPR_AND_AND ) - { - if ( v1->IsZero() ) - return v1; - else - return op2->Eval(f); - } + if ( tag == EXPR_AND_AND ) { + if ( v1->IsZero() ) + return v1; + else + return op2->Eval(f); + } - else - { - if ( v1->IsZero() ) - return op2->Eval(f); - else - return v1; - } - } + else { + if ( v1->IsZero() ) + return op2->Eval(f); + else + return v1; + } +} -ValPtr BoolExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; +ValPtr BoolExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; - auto v1 = op1->Eval(f); + auto v1 = op1->Eval(f); - if ( ! v1 ) - return nullptr; + if ( ! v1 ) + return nullptr; - bool is_vec1 = is_vector(op1); - bool is_vec2 = is_vector(op2); + bool is_vec1 = is_vector(op1); + bool is_vec2 = is_vector(op2); - // Handle scalar op scalar - if ( ! is_vec1 && ! is_vec2 ) - return DoSingleEval(f, std::move(v1), op2.get()); + // Handle scalar op scalar + if ( ! is_vec1 && ! is_vec2 ) + return DoSingleEval(f, std::move(v1), op2.get()); - // Both are vectors. - auto v2 = op2->Eval(f); + // Both are vectors. + auto v2 = op2->Eval(f); - if ( ! v2 ) - return nullptr; + if ( ! v2 ) + return nullptr; - VectorVal* vec_v1 = v1->AsVectorVal(); - VectorVal* vec_v2 = v2->AsVectorVal(); + VectorVal* vec_v1 = v1->AsVectorVal(); + VectorVal* vec_v2 = v2->AsVectorVal(); - if ( vec_v1->Size() != vec_v2->Size() ) - { - RuntimeError("vector operands have different sizes"); - return nullptr; - } + if ( vec_v1->Size() != vec_v2->Size() ) { + RuntimeError("vector operands have different sizes"); + return nullptr; + } - auto result = make_intrusive(GetType()); - result->Resize(vec_v1->Size()); + auto result = make_intrusive(GetType()); + result->Resize(vec_v1->Size()); - for ( unsigned int i = 0; i < vec_v1->Size(); ++i ) - { - const auto op1 = vec_v1->BoolAt(i); - const auto op2 = vec_v2->BoolAt(i); + for ( unsigned int i = 0; i < vec_v1->Size(); ++i ) { + const auto op1 = vec_v1->BoolAt(i); + const auto op2 = vec_v2->BoolAt(i); - bool local_result = (tag == EXPR_AND_AND) ? (op1 && op2) : (op1 || op2); + bool local_result = (tag == EXPR_AND_AND) ? (op1 && op2) : (op1 || op2); - result->Assign(i, val_mgr->Bool(local_result)); - } + result->Assign(i, val_mgr->Bool(local_result)); + } - return result; - } + return result; +} BitExpr::BitExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - const auto& t1 = op1->GetType(); - const auto& t2 = op2->GetType(); + const auto& t1 = op1->GetType(); + const auto& t2 = op2->GetType(); - TypeTag bt1 = t1->Tag(); + TypeTag bt1 = t1->Tag(); - if ( IsVector(bt1) ) - bt1 = t1->AsVectorType()->Yield()->Tag(); + if ( IsVector(bt1) ) + bt1 = t1->AsVectorType()->Yield()->Tag(); - TypeTag bt2 = t2->Tag(); + TypeTag bt2 = t2->Tag(); - if ( IsVector(bt2) ) - bt2 = t2->AsVectorType()->Yield()->Tag(); + if ( IsVector(bt2) ) + bt2 = t2->AsVectorType()->Yield()->Tag(); - if ( tag == EXPR_LSHIFT || tag == EXPR_RSHIFT ) - { - if ( (is_vector(op1) || is_vector(op2)) && ! (is_vector(op1) && is_vector(op2)) ) - ExprError("cannot mix vectors and scalars for shift operations"); + if ( tag == EXPR_LSHIFT || tag == EXPR_RSHIFT ) { + if ( (is_vector(op1) || is_vector(op2)) && ! (is_vector(op1) && is_vector(op2)) ) + ExprError("cannot mix vectors and scalars for shift operations"); - if ( IsIntegral(bt1) && bt2 == TYPE_COUNT ) - { - if ( is_vector(op1) || is_vector(op2) ) - SetType(make_intrusive(base_type(bt1))); - else - SetType(base_type(bt1)); - } + if ( IsIntegral(bt1) && bt2 == TYPE_COUNT ) { + if ( is_vector(op1) || is_vector(op2) ) + SetType(make_intrusive(base_type(bt1))); + else + SetType(base_type(bt1)); + } - else if ( IsIntegral(bt1) && bt2 == TYPE_INT ) - ExprError("requires \"count\" right operand"); + else if ( IsIntegral(bt1) && bt2 == TYPE_INT ) + ExprError("requires \"count\" right operand"); - else - ExprError("requires integral operands"); + else + ExprError("requires integral operands"); - return; // because following scalar check isn't apt - } + return; // because following scalar check isn't apt + } - if ( (bt1 == TYPE_COUNT) && (bt2 == TYPE_COUNT) ) - { - if ( is_vector(op1) || is_vector(op2) ) - SetType(make_intrusive(base_type(TYPE_COUNT))); - else - SetType(base_type(TYPE_COUNT)); - } + if ( (bt1 == TYPE_COUNT) && (bt2 == TYPE_COUNT) ) { + if ( is_vector(op1) || is_vector(op2) ) + SetType(make_intrusive(base_type(TYPE_COUNT))); + else + SetType(base_type(TYPE_COUNT)); + } - else if ( bt1 == TYPE_PATTERN ) - { - if ( bt2 != TYPE_PATTERN ) - ExprError("cannot mix pattern and non-pattern operands"); - else if ( tag == EXPR_XOR ) - ExprError("'^' operator does not apply to patterns"); - else - SetType(base_type(TYPE_PATTERN)); - } + else if ( bt1 == TYPE_PATTERN ) { + if ( bt2 != TYPE_PATTERN ) + ExprError("cannot mix pattern and non-pattern operands"); + else if ( tag == EXPR_XOR ) + ExprError("'^' operator does not apply to patterns"); + else + SetType(base_type(TYPE_PATTERN)); + } - else if ( t1->IsSet() && t2->IsSet() ) - { - if ( same_type(t1, t2) ) - SetType(op1->GetType()); - else - ExprError("incompatible \"set\" operands"); - } + else if ( t1->IsSet() && t2->IsSet() ) { + if ( same_type(t1, t2) ) + SetType(op1->GetType()); + else + ExprError("incompatible \"set\" operands"); + } - else - ExprError("requires \"count\" or compatible \"set\" operands"); - } + else + ExprError("requires \"count\" or compatible \"set\" operands"); +} EqExpr::EqExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - Canonicalize(); + Canonicalize(); - const auto& t1 = op1->GetType(); - const auto& t2 = op2->GetType(); + const auto& t1 = op1->GetType(); + const auto& t2 = op2->GetType(); - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - if ( is_vector(op1) ) - SetType(make_intrusive(base_type(TYPE_BOOL))); - else - SetType(base_type(TYPE_BOOL)); + if ( is_vector(op1) ) + SetType(make_intrusive(base_type(TYPE_BOOL))); + else + SetType(base_type(TYPE_BOOL)); - if ( BothArithmetic(bt1, bt2) ) - PromoteOps(max_type(bt1, bt2)); + if ( BothArithmetic(bt1, bt2) ) + PromoteOps(max_type(bt1, bt2)); - else if ( EitherArithmetic(bt1, bt2) && - // Allow comparisons with zero. - ((bt1 == TYPE_TIME && op2->IsZero()) || (bt2 == TYPE_TIME && op1->IsZero())) ) - PromoteOps(TYPE_TIME); + else if ( EitherArithmetic(bt1, bt2) && + // Allow comparisons with zero. + ((bt1 == TYPE_TIME && op2->IsZero()) || (bt2 == TYPE_TIME && op1->IsZero())) ) + PromoteOps(TYPE_TIME); - else if ( bt1 == bt2 ) - { - switch ( bt1 ) - { - case TYPE_BOOL: - case TYPE_TIME: - case TYPE_INTERVAL: - case TYPE_STRING: - case TYPE_PORT: - case TYPE_ADDR: - case TYPE_SUBNET: - case TYPE_ERROR: - case TYPE_FUNC: - break; + else if ( bt1 == bt2 ) { + switch ( bt1 ) { + case TYPE_BOOL: + case TYPE_TIME: + case TYPE_INTERVAL: + case TYPE_STRING: + case TYPE_PORT: + case TYPE_ADDR: + case TYPE_SUBNET: + case TYPE_ERROR: + case TYPE_FUNC: break; - case TYPE_ENUM: - if ( ! same_type(t1, t2) ) - ExprError("illegal enum comparison"); - break; + case TYPE_ENUM: + if ( ! same_type(t1, t2) ) + ExprError("illegal enum comparison"); + break; - case TYPE_TABLE: - if ( t1->IsSet() && t2->IsSet() ) - { - if ( ! same_type(t1, t2) ) - ExprError("incompatible sets in comparison"); - break; - } + case TYPE_TABLE: + if ( t1->IsSet() && t2->IsSet() ) { + if ( ! same_type(t1, t2) ) + ExprError("incompatible sets in comparison"); + break; + } - // FALL THROUGH + // FALL THROUGH - default: - ExprError("illegal comparison"); - } - } + default: ExprError("illegal comparison"); + } + } - else if ( bt1 == TYPE_PATTERN && bt2 == TYPE_STRING ) - ; + else if ( bt1 == TYPE_PATTERN && bt2 == TYPE_STRING ) + ; - else - ExprError("type clash in comparison"); - } + else + ExprError("type clash in comparison"); +} -void EqExpr::Canonicalize() - { - if ( op2->GetType()->Tag() == TYPE_PATTERN ) - SwapOps(); +void EqExpr::Canonicalize() { + if ( op2->GetType()->Tag() == TYPE_PATTERN ) + SwapOps(); - else if ( op1->GetType()->Tag() == TYPE_PATTERN ) - ; + else if ( op1->GetType()->Tag() == TYPE_PATTERN ) + ; - else if ( expr_greater(op2.get(), op1.get()) ) - SwapOps(); - } + else if ( expr_greater(op2.get(), op1.get()) ) + SwapOps(); +} -ValPtr EqExpr::Fold(Val* v1, Val* v2) const - { - if ( op1->GetType()->Tag() == TYPE_PATTERN ) - { - auto re = v1->As(); - const String* s = v2->AsString(); - if ( tag == EXPR_EQ ) - return val_mgr->Bool(re->MatchExactly(s)); - else - return val_mgr->Bool(! re->MatchExactly(s)); - } - else if ( op1->GetType()->Tag() == TYPE_FUNC ) - { - auto res = v1->AsFunc() == v2->AsFunc(); - return val_mgr->Bool(tag == EXPR_EQ ? res : ! res); - } +ValPtr EqExpr::Fold(Val* v1, Val* v2) const { + if ( op1->GetType()->Tag() == TYPE_PATTERN ) { + auto re = v1->As(); + const String* s = v2->AsString(); + if ( tag == EXPR_EQ ) + return val_mgr->Bool(re->MatchExactly(s)); + else + return val_mgr->Bool(! re->MatchExactly(s)); + } + else if ( op1->GetType()->Tag() == TYPE_FUNC ) { + auto res = v1->AsFunc() == v2->AsFunc(); + return val_mgr->Bool(tag == EXPR_EQ ? res : ! res); + } - else - return BinaryExpr::Fold(v1, v2); - } + else + return BinaryExpr::Fold(v1, v2); +} -bool EqExpr::InvertSense() - { - tag = (tag == EXPR_EQ ? EXPR_NE : EXPR_EQ); - return true; - } +bool EqExpr::InvertSense() { + tag = (tag == EXPR_EQ ? EXPR_NE : EXPR_EQ); + return true; +} RelExpr::RelExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; + : BinaryExpr(arg_tag, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - Canonicalize(); + Canonicalize(); - const auto& t1 = op1->GetType(); - const auto& t2 = op2->GetType(); + const auto& t1 = op1->GetType(); + const auto& t2 = op2->GetType(); - TypeTag bt1, bt2; - if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) - return; + TypeTag bt1, bt2; + if ( ! get_types_from_scalars_or_vectors(this, bt1, bt2) ) + return; - if ( is_vector(op1) ) - SetType(make_intrusive(base_type(TYPE_BOOL))); - else - SetType(base_type(TYPE_BOOL)); + if ( is_vector(op1) ) + SetType(make_intrusive(base_type(TYPE_BOOL))); + else + SetType(base_type(TYPE_BOOL)); - if ( BothArithmetic(bt1, bt2) ) - PromoteOps(max_type(bt1, bt2)); + if ( BothArithmetic(bt1, bt2) ) + PromoteOps(max_type(bt1, bt2)); - else if ( t1->IsSet() && t2->IsSet() ) - { - if ( ! same_type(t1, t2) ) - ExprError("incompatible sets in comparison"); - } + else if ( t1->IsSet() && t2->IsSet() ) { + if ( ! same_type(t1, t2) ) + ExprError("incompatible sets in comparison"); + } - else if ( bt1 != bt2 ) - ExprError("operands must be of the same type"); + else if ( bt1 != bt2 ) + ExprError("operands must be of the same type"); - else if ( bt1 != TYPE_TIME && bt1 != TYPE_INTERVAL && bt1 != TYPE_PORT && bt1 != TYPE_ADDR && - bt1 != TYPE_STRING ) - ExprError("illegal comparison"); - } + else if ( bt1 != TYPE_TIME && bt1 != TYPE_INTERVAL && bt1 != TYPE_PORT && bt1 != TYPE_ADDR && bt1 != TYPE_STRING ) + ExprError("illegal comparison"); +} -void RelExpr::Canonicalize() - { - if ( tag == EXPR_GT ) - { - SwapOps(); - tag = EXPR_LT; - } +void RelExpr::Canonicalize() { + if ( tag == EXPR_GT ) { + SwapOps(); + tag = EXPR_LT; + } - else if ( tag == EXPR_GE ) - { - SwapOps(); - tag = EXPR_LE; - } - } + else if ( tag == EXPR_GE ) { + SwapOps(); + tag = EXPR_LE; + } +} -bool RelExpr::InvertSense() - { - switch ( tag ) - { - case EXPR_LT: - tag = EXPR_GE; - return true; - case EXPR_LE: - tag = EXPR_GT; - return true; - case EXPR_GE: - tag = EXPR_LT; - return true; - case EXPR_GT: - tag = EXPR_LE; - return true; +bool RelExpr::InvertSense() { + switch ( tag ) { + case EXPR_LT: tag = EXPR_GE; return true; + case EXPR_LE: tag = EXPR_GT; return true; + case EXPR_GE: tag = EXPR_LT; return true; + case EXPR_GT: tag = EXPR_LE; return true; - default: - return false; - } - } + default: return false; + } +} CondExpr::CondExpr(ExprPtr arg_op1, ExprPtr arg_op2, ExprPtr arg_op3) - : Expr(EXPR_COND), op1(std::move(arg_op1)), op2(std::move(arg_op2)), op3(std::move(arg_op3)) - { - TypeTag bt1 = op1->GetType()->Tag(); - - if ( IsVector(bt1) ) - bt1 = op1->GetType()->AsVectorType()->Yield()->Tag(); - - if ( op1->IsError() || op2->IsError() || op3->IsError() ) - SetError(); - - else if ( bt1 != TYPE_BOOL ) - ExprError("requires boolean conditional"); - - else - { - TypeTag bt2 = op2->GetType()->Tag(); - TypeTag bt3 = op3->GetType()->Tag(); - - if ( is_vector(op1) ) - { - if ( ! (is_vector(op2) && is_vector(op3)) ) - { - ExprError("vector conditional requires vector alternatives"); - return; - } - - bt2 = op2->GetType()->AsVectorType()->Yield()->Tag(); - bt3 = op3->GetType()->AsVectorType()->Yield()->Tag(); - } - - if ( BothArithmetic(bt2, bt3) ) - { - TypeTag t = max_type(bt2, bt3); - if ( bt2 != t ) - op2 = make_intrusive(std::move(op2), t); - if ( bt3 != t ) - op3 = make_intrusive(std::move(op3), t); - - if ( is_vector(op1) ) - SetType(make_intrusive(base_type(t))); - else - SetType(base_type(t)); - } - - else if ( bt2 != bt3 ) - ExprError("operands must be of the same type"); - - else - { - if ( is_vector(op1) ) - { - ExprError("vector conditional type clash between alternatives"); - return; - } - - if ( bt2 == zeek::TYPE_TABLE ) - { - auto tt2 = op2->GetType(); - auto tt3 = op3->GetType(); - - if ( tt2->IsUnspecifiedTable() ) - op2 = make_intrusive(std::move(op2), std::move(tt3)); - else if ( tt3->IsUnspecifiedTable() ) - op3 = make_intrusive(std::move(op3), std::move(tt2)); - else if ( ! same_type(op2->GetType(), op3->GetType()) ) - ExprError("operands must be of the same type"); - } - else if ( bt2 == zeek::TYPE_VECTOR ) - { - auto vt2 = op2->GetType(); - auto vt3 = op3->GetType(); - - if ( vt2->IsUnspecifiedVector() ) - op2 = make_intrusive(std::move(op2), std::move(vt3)); - else if ( vt3->IsUnspecifiedVector() ) - op3 = make_intrusive(std::move(op3), std::move(vt2)); - else if ( ! same_type(op2->GetType(), op3->GetType()) ) - ExprError("operands must be of the same type"); - } - else if ( ! same_type(op2->GetType(), op3->GetType()) ) - // Records could potentially also coerce, but may have some cases - // where the coercion direction is ambiguous. - ExprError("operands must be of the same type"); - - if ( ! IsError() ) - SetType(op2->GetType()); - } - } - } - -ValPtr CondExpr::Eval(Frame* f) const - { - if ( ! is_vector(op1) ) - { - // Scalar case - auto false_eval = op1->Eval(f)->IsZero(); - return (false_eval ? op3 : op2)->Eval(f); - } - - // Vector case: no mixed scalar/vector cases allowed - auto v1 = op1->Eval(f); - - if ( ! v1 ) - return nullptr; - - auto v2 = op2->Eval(f); - - if ( ! v2 ) - return nullptr; - - auto v3 = op3->Eval(f); - - if ( ! v3 ) - return nullptr; - - VectorVal* cond = v1->AsVectorVal(); - VectorVal* a = v2->AsVectorVal(); - VectorVal* b = v3->AsVectorVal(); - - if ( cond->Size() != a->Size() || a->Size() != b->Size() ) - { - RuntimeError("vectors in conditional expression have different sizes"); - return nullptr; - } - - auto result = make_intrusive(GetType()); - result->Resize(cond->Size()); - - for ( unsigned int i = 0; i < cond->Size(); ++i ) - { - auto local_cond = cond->BoolAt(i); - auto v = local_cond ? a->ValAt(i) : b->ValAt(i); - result->Assign(i, v); - } - - return result; - } - -bool CondExpr::IsPure() const - { - return op1->IsPure() && op2->IsPure() && op3->IsPure(); - } - -TraversalCode CondExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - tc = op1->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = op2->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = op3->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -void CondExpr::ExprDescribe(ODesc* d) const - { - op1->Describe(d); - d->AddSP(" ?"); - op2->Describe(d); - d->AddSP(" :"); - op3->Describe(d); - } - -RefExpr::RefExpr(ExprPtr arg_op) : UnaryExpr(EXPR_REF, std::move(arg_op)) - { - if ( IsError() ) - return; - - if ( ! is_assignable(op->GetType()->Tag()) ) - ExprError("illegal assignment target"); - else - SetType(op->GetType()); - } - -ExprPtr RefExpr::MakeLvalue() - { - return ThisPtr(); - } - -void RefExpr::Assign(Frame* f, ValPtr v) - { - op->Assign(f, std::move(v)); - } - -AssignExpr::AssignExpr(ExprPtr arg_op1, ExprPtr arg_op2, bool arg_is_init, ValPtr arg_val, - const AttributesPtr& attrs, bool typecheck) - : BinaryExpr(EXPR_ASSIGN, arg_is_init ? std::move(arg_op1) : arg_op1->MakeLvalue(), - std::move(arg_op2)) - { - val = nullptr; - is_init = arg_is_init; - - if ( IsError() ) - return; - - if ( arg_val ) - SetType(arg_val->GetType()); - else - SetType(op1->GetType()); - - if ( is_init ) - { - SetLocationInfo(op1->GetLocationInfo(), op2->GetLocationInfo()); - return; - } - - if ( op2->Tag() == EXPR_LIST && CheckForRHSList() ) - { - if ( op2->Tag() == EXPR_TABLE_CONSTRUCTOR ) - cast_intrusive(op2)->SetAttrs(attrs); - else if ( op2->Tag() == EXPR_SET_CONSTRUCTOR ) - cast_intrusive(op2)->SetAttrs(attrs); - } - - else if ( typecheck ) - // We discard the status from TypeCheck since it has already - // generated error messages. - (void)TypeCheck(attrs); - - val = std::move(arg_val); - - SetLocationInfo(op1->GetLocationInfo(), op2->GetLocationInfo()); - } - -bool AssignExpr::TypeCheck(const AttributesPtr& attrs) - { - TypeTag bt1 = op1->GetType()->Tag(); - TypeTag bt2 = op2->GetType()->Tag(); - - if ( bt1 == TYPE_LIST && bt2 == TYPE_ANY ) - // This is ok because we cannot explicitly declare lists on - // the script level. - return true; - - // This should be one of them, but not both (i.e. XOR) - if ( ((bt1 == TYPE_ENUM) ^ (bt2 == TYPE_ENUM)) ) - { - ExprError("can't convert to/from enumerated type"); - return false; - } - - if ( IsArithmetic(bt1) ) - return TypeCheckArithmetics(bt1, bt2); - - if ( bt1 == TYPE_TIME && IsArithmetic(bt2) && op2->IsZero() ) - { // Allow assignments to zero as a special case. - op2 = make_intrusive(std::move(op2), bt1); - return true; - } - - if ( bt1 == TYPE_TABLE && bt2 == bt1 && op2->GetType()->AsTableType()->IsUnspecifiedTable() ) - { - op2 = make_intrusive(std::move(op2), op1->GetType()); - return true; - } - - if ( bt1 == TYPE_VECTOR ) - { - if ( bt2 == bt1 && op2->GetType()->AsVectorType()->IsUnspecifiedVector() ) - { - op2 = make_intrusive(std::move(op2), op1->GetType()); - return true; - } - - if ( op2->Tag() == EXPR_LIST ) - { - op2 = make_intrusive(cast_intrusive(op2), - op1->GetType()); - return true; - } - } - - if ( op1->GetType()->Tag() == TYPE_RECORD && op2->GetType()->Tag() == TYPE_RECORD ) - { - if ( same_type(op1->GetType(), op2->GetType()) ) - return true; - - // Need to coerce. - op2 = make_intrusive(std::move(op2), op1->GetType()); - return true; - } - - if ( ! same_type(op1->GetType(), op2->GetType()) ) - { - if ( bt1 == TYPE_TABLE && bt2 == TYPE_TABLE ) - { - if ( op2->Tag() == EXPR_SET_CONSTRUCTOR ) - { - // Some elements in constructor list must not match, see if - // we can create a new constructor now that the expected type - // of LHS is known and let it do coercions where possible. - auto sce = cast_intrusive(op2); - auto ctor_list = cast_intrusive(sce->GetOp1()); - - if ( ! ctor_list ) - Internal("failed typecast to ListExpr"); - - std::unique_ptr> attr_copy; - - if ( sce->GetAttrs() ) - { - const auto& a = sce->GetAttrs()->GetAttrs(); - attr_copy = std::make_unique>(a); - } - - int errors_before = reporter->Errors(); - op2 = make_intrusive(ctor_list, std::move(attr_copy), - op1->GetType()); - int errors_after = reporter->Errors(); - - if ( errors_after > errors_before ) - { - ExprError("type clash in assignment"); - return false; - } - - return true; - } - } - - ExprError("type clash in assignment"); - return false; - } - - return true; - } - -bool AssignExpr::TypeCheckArithmetics(TypeTag bt1, TypeTag bt2) - { - if ( ! IsArithmetic(bt2) ) - { - ExprError(util::fmt("assignment of non-arithmetic value to arithmetic (%s/%s)", - type_name(bt1), type_name(bt2))); - return false; - } - - if ( bt1 == TYPE_DOUBLE ) - { - PromoteOps(TYPE_DOUBLE); - return true; - } - - if ( bt2 == TYPE_DOUBLE ) - { - Warn("dangerous assignment of double to integral"); - op2 = make_intrusive(std::move(op2), bt1); - bt2 = op2->GetType()->Tag(); - } - - if ( bt1 == TYPE_INT ) - PromoteOps(TYPE_INT); - else - { - if ( bt2 == TYPE_INT ) - { - Warn("dangerous assignment of integer to count"); - op2 = make_intrusive(std::move(op2), bt1); - } - - // Assignment of count to counter or vice - // versa is allowed, and requires no - // coercion. - } - - return true; - } - -ValPtr AssignExpr::Eval(Frame* f) const - { - if ( is_init ) - { - RuntimeError("illegal assignment in initialization"); - return nullptr; - } - - if ( auto v = op2->Eval(f) ) - { - op1->Assign(f, v); - - if ( val ) - return val; - - return v; - } - else - return nullptr; - } - -TypePtr AssignExpr::InitType() const - { - if ( op1->Tag() != EXPR_LIST ) - { - Error("bad initializer, first operand should be a list"); - return nullptr; - } - - const auto& tl = op1->GetType(); - if ( tl->Tag() != TYPE_LIST ) - Internal("inconsistent list expr in AssignExpr::InitType"); - - return make_intrusive(IntrusivePtr{NewRef{}, tl->AsTypeList()}, op2->GetType()); - } - -bool AssignExpr::IsRecordElement(TypeDecl* td) const - { - if ( op1->Tag() == EXPR_NAME ) - { - if ( td ) - { - const NameExpr* n = (const NameExpr*)op1.get(); - td->type = op2->GetType(); - td->id = util::copy_string(n->Id()->Name()); - } - - return true; - } - - return false; - } + : Expr(EXPR_COND), op1(std::move(arg_op1)), op2(std::move(arg_op2)), op3(std::move(arg_op3)) { + TypeTag bt1 = op1->GetType()->Tag(); + + if ( IsVector(bt1) ) + bt1 = op1->GetType()->AsVectorType()->Yield()->Tag(); + + if ( op1->IsError() || op2->IsError() || op3->IsError() ) + SetError(); + + else if ( bt1 != TYPE_BOOL ) + ExprError("requires boolean conditional"); + + else { + TypeTag bt2 = op2->GetType()->Tag(); + TypeTag bt3 = op3->GetType()->Tag(); + + if ( is_vector(op1) ) { + if ( ! (is_vector(op2) && is_vector(op3)) ) { + ExprError("vector conditional requires vector alternatives"); + return; + } + + bt2 = op2->GetType()->AsVectorType()->Yield()->Tag(); + bt3 = op3->GetType()->AsVectorType()->Yield()->Tag(); + } + + if ( BothArithmetic(bt2, bt3) ) { + TypeTag t = max_type(bt2, bt3); + if ( bt2 != t ) + op2 = make_intrusive(std::move(op2), t); + if ( bt3 != t ) + op3 = make_intrusive(std::move(op3), t); + + if ( is_vector(op1) ) + SetType(make_intrusive(base_type(t))); + else + SetType(base_type(t)); + } + + else if ( bt2 != bt3 ) + ExprError("operands must be of the same type"); + + else { + if ( is_vector(op1) ) { + ExprError("vector conditional type clash between alternatives"); + return; + } + + if ( bt2 == zeek::TYPE_TABLE ) { + auto tt2 = op2->GetType(); + auto tt3 = op3->GetType(); + + if ( tt2->IsUnspecifiedTable() ) + op2 = make_intrusive(std::move(op2), std::move(tt3)); + else if ( tt3->IsUnspecifiedTable() ) + op3 = make_intrusive(std::move(op3), std::move(tt2)); + else if ( ! same_type(op2->GetType(), op3->GetType()) ) + ExprError("operands must be of the same type"); + } + else if ( bt2 == zeek::TYPE_VECTOR ) { + auto vt2 = op2->GetType(); + auto vt3 = op3->GetType(); + + if ( vt2->IsUnspecifiedVector() ) + op2 = make_intrusive(std::move(op2), std::move(vt3)); + else if ( vt3->IsUnspecifiedVector() ) + op3 = make_intrusive(std::move(op3), std::move(vt2)); + else if ( ! same_type(op2->GetType(), op3->GetType()) ) + ExprError("operands must be of the same type"); + } + else if ( ! same_type(op2->GetType(), op3->GetType()) ) + // Records could potentially also coerce, but may have some cases + // where the coercion direction is ambiguous. + ExprError("operands must be of the same type"); + + if ( ! IsError() ) + SetType(op2->GetType()); + } + } +} + +ValPtr CondExpr::Eval(Frame* f) const { + if ( ! is_vector(op1) ) { + // Scalar case + auto false_eval = op1->Eval(f)->IsZero(); + return (false_eval ? op3 : op2)->Eval(f); + } + + // Vector case: no mixed scalar/vector cases allowed + auto v1 = op1->Eval(f); + + if ( ! v1 ) + return nullptr; + + auto v2 = op2->Eval(f); + + if ( ! v2 ) + return nullptr; + + auto v3 = op3->Eval(f); + + if ( ! v3 ) + return nullptr; + + VectorVal* cond = v1->AsVectorVal(); + VectorVal* a = v2->AsVectorVal(); + VectorVal* b = v3->AsVectorVal(); + + if ( cond->Size() != a->Size() || a->Size() != b->Size() ) { + RuntimeError("vectors in conditional expression have different sizes"); + return nullptr; + } + + auto result = make_intrusive(GetType()); + result->Resize(cond->Size()); + + for ( unsigned int i = 0; i < cond->Size(); ++i ) { + auto local_cond = cond->BoolAt(i); + auto v = local_cond ? a->ValAt(i) : b->ValAt(i); + result->Assign(i, v); + } + + return result; +} + +bool CondExpr::IsPure() const { return op1->IsPure() && op2->IsPure() && op3->IsPure(); } + +TraversalCode CondExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = op1->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = op2->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = op3->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +void CondExpr::ExprDescribe(ODesc* d) const { + op1->Describe(d); + d->AddSP(" ?"); + op2->Describe(d); + d->AddSP(" :"); + op3->Describe(d); +} + +RefExpr::RefExpr(ExprPtr arg_op) : UnaryExpr(EXPR_REF, std::move(arg_op)) { + if ( IsError() ) + return; + + if ( ! is_assignable(op->GetType()->Tag()) ) + ExprError("illegal assignment target"); + else + SetType(op->GetType()); +} + +ExprPtr RefExpr::MakeLvalue() { return ThisPtr(); } + +void RefExpr::Assign(Frame* f, ValPtr v) { op->Assign(f, std::move(v)); } + +AssignExpr::AssignExpr(ExprPtr arg_op1, ExprPtr arg_op2, bool arg_is_init, ValPtr arg_val, const AttributesPtr& attrs, + bool typecheck) + : BinaryExpr(EXPR_ASSIGN, arg_is_init ? std::move(arg_op1) : arg_op1->MakeLvalue(), std::move(arg_op2)) { + val = nullptr; + is_init = arg_is_init; + + if ( IsError() ) + return; + + if ( arg_val ) + SetType(arg_val->GetType()); + else + SetType(op1->GetType()); + + if ( is_init ) { + SetLocationInfo(op1->GetLocationInfo(), op2->GetLocationInfo()); + return; + } + + if ( op2->Tag() == EXPR_LIST && CheckForRHSList() ) { + if ( op2->Tag() == EXPR_TABLE_CONSTRUCTOR ) + cast_intrusive(op2)->SetAttrs(attrs); + else if ( op2->Tag() == EXPR_SET_CONSTRUCTOR ) + cast_intrusive(op2)->SetAttrs(attrs); + } + + else if ( typecheck ) + // We discard the status from TypeCheck since it has already + // generated error messages. + (void)TypeCheck(attrs); + + val = std::move(arg_val); + + SetLocationInfo(op1->GetLocationInfo(), op2->GetLocationInfo()); +} + +bool AssignExpr::TypeCheck(const AttributesPtr& attrs) { + TypeTag bt1 = op1->GetType()->Tag(); + TypeTag bt2 = op2->GetType()->Tag(); + + if ( bt1 == TYPE_LIST && bt2 == TYPE_ANY ) + // This is ok because we cannot explicitly declare lists on + // the script level. + return true; + + // This should be one of them, but not both (i.e. XOR) + if ( ((bt1 == TYPE_ENUM) ^ (bt2 == TYPE_ENUM)) ) { + ExprError("can't convert to/from enumerated type"); + return false; + } + + if ( IsArithmetic(bt1) ) + return TypeCheckArithmetics(bt1, bt2); + + if ( bt1 == TYPE_TIME && IsArithmetic(bt2) && op2->IsZero() ) { // Allow assignments to zero as a special case. + op2 = make_intrusive(std::move(op2), bt1); + return true; + } + + if ( bt1 == TYPE_TABLE && bt2 == bt1 && op2->GetType()->AsTableType()->IsUnspecifiedTable() ) { + op2 = make_intrusive(std::move(op2), op1->GetType()); + return true; + } + + if ( bt1 == TYPE_VECTOR ) { + if ( bt2 == bt1 && op2->GetType()->AsVectorType()->IsUnspecifiedVector() ) { + op2 = make_intrusive(std::move(op2), op1->GetType()); + return true; + } + + if ( op2->Tag() == EXPR_LIST ) { + op2 = make_intrusive(cast_intrusive(op2), op1->GetType()); + return true; + } + } + + if ( op1->GetType()->Tag() == TYPE_RECORD && op2->GetType()->Tag() == TYPE_RECORD ) { + if ( same_type(op1->GetType(), op2->GetType()) ) + return true; + + // Need to coerce. + op2 = make_intrusive(std::move(op2), op1->GetType()); + return true; + } + + if ( ! same_type(op1->GetType(), op2->GetType()) ) { + if ( bt1 == TYPE_TABLE && bt2 == TYPE_TABLE ) { + if ( op2->Tag() == EXPR_SET_CONSTRUCTOR ) { + // Some elements in constructor list must not match, see if + // we can create a new constructor now that the expected type + // of LHS is known and let it do coercions where possible. + auto sce = cast_intrusive(op2); + auto ctor_list = cast_intrusive(sce->GetOp1()); + + if ( ! ctor_list ) + Internal("failed typecast to ListExpr"); + + std::unique_ptr> attr_copy; + + if ( sce->GetAttrs() ) { + const auto& a = sce->GetAttrs()->GetAttrs(); + attr_copy = std::make_unique>(a); + } + + int errors_before = reporter->Errors(); + op2 = make_intrusive(ctor_list, std::move(attr_copy), op1->GetType()); + int errors_after = reporter->Errors(); + + if ( errors_after > errors_before ) { + ExprError("type clash in assignment"); + return false; + } + + return true; + } + } + + ExprError("type clash in assignment"); + return false; + } + + return true; +} + +bool AssignExpr::TypeCheckArithmetics(TypeTag bt1, TypeTag bt2) { + if ( ! IsArithmetic(bt2) ) { + ExprError( + util::fmt("assignment of non-arithmetic value to arithmetic (%s/%s)", type_name(bt1), type_name(bt2))); + return false; + } + + if ( bt1 == TYPE_DOUBLE ) { + PromoteOps(TYPE_DOUBLE); + return true; + } + + if ( bt2 == TYPE_DOUBLE ) { + Warn("dangerous assignment of double to integral"); + op2 = make_intrusive(std::move(op2), bt1); + bt2 = op2->GetType()->Tag(); + } + + if ( bt1 == TYPE_INT ) + PromoteOps(TYPE_INT); + else { + if ( bt2 == TYPE_INT ) { + Warn("dangerous assignment of integer to count"); + op2 = make_intrusive(std::move(op2), bt1); + } + + // Assignment of count to counter or vice + // versa is allowed, and requires no + // coercion. + } + + return true; +} + +ValPtr AssignExpr::Eval(Frame* f) const { + if ( is_init ) { + RuntimeError("illegal assignment in initialization"); + return nullptr; + } + + if ( auto v = op2->Eval(f) ) { + op1->Assign(f, v); + + if ( val ) + return val; + + return v; + } + else + return nullptr; +} + +TypePtr AssignExpr::InitType() const { + if ( op1->Tag() != EXPR_LIST ) { + Error("bad initializer, first operand should be a list"); + return nullptr; + } + + const auto& tl = op1->GetType(); + if ( tl->Tag() != TYPE_LIST ) + Internal("inconsistent list expr in AssignExpr::InitType"); + + return make_intrusive(IntrusivePtr{NewRef{}, tl->AsTypeList()}, op2->GetType()); +} + +bool AssignExpr::IsRecordElement(TypeDecl* td) const { + if ( op1->Tag() == EXPR_NAME ) { + if ( td ) { + const NameExpr* n = (const NameExpr*)op1.get(); + td->type = op2->GetType(); + td->id = util::copy_string(n->Id()->Name()); + } + + return true; + } + + return false; +} IndexSliceAssignExpr::IndexSliceAssignExpr(ExprPtr op1, ExprPtr op2, bool is_init) - : AssignExpr(std::move(op1), std::move(op2), is_init) - { - } - -ValPtr IndexSliceAssignExpr::Eval(Frame* f) const - { - if ( is_init ) - { - RuntimeError("illegal assignment in initialization"); - return nullptr; - } - - if ( auto v = op2->Eval(f) ) - op1->Assign(f, std::move(v)); - - return nullptr; - } - -IndexExpr::IndexExpr(ExprPtr arg_op1, ListExprPtr arg_op2, bool arg_is_slice, - bool arg_is_inside_when) - : BinaryExpr(EXPR_INDEX, std::move(arg_op1), std::move(arg_op2)), is_slice(arg_is_slice), - is_inside_when(arg_is_inside_when) - { - if ( IsError() ) - return; - - if ( is_slice ) - { - if ( ! IsString(op1->GetType()->Tag()) && ! IsVector(op1->GetType()->Tag()) ) - ExprError("slice notation indexing only supported for strings and vectors currently"); - } - - else if ( IsString(op1->GetType()->Tag()) ) - { - if ( op2->AsListExpr()->Exprs().length() != 1 ) - ExprError("invalid string index expression"); - } - - if ( IsError() ) - return; - - int match_type = op1->GetType()->MatchesIndex(op2->AsListExpr()); - - if ( match_type == DOES_NOT_MATCH_INDEX ) - { - std::string error_msg = util::fmt( - "expression with type '%s' is not a type that can be indexed", - type_name(op1->GetType()->Tag())); - SetError(error_msg.data()); - } + : AssignExpr(std::move(op1), std::move(op2), is_init) {} - else if ( ! op1->GetType()->Yield() ) - { - if ( IsString(op1->GetType()->Tag()) && match_type == MATCHES_INDEX_SCALAR ) - SetType(base_type(TYPE_STRING)); - else - // It's a set - so indexing it yields void. We don't - // directly generate an error message, though, since this - // expression might be part of an add/delete statement, - // rather than yielding a value. - SetType(base_type(TYPE_VOID)); - } +ValPtr IndexSliceAssignExpr::Eval(Frame* f) const { + if ( is_init ) { + RuntimeError("illegal assignment in initialization"); + return nullptr; + } - else if ( match_type == MATCHES_INDEX_SCALAR ) - SetType(op1->GetType()->Yield()); - - else if ( match_type == MATCHES_INDEX_VECTOR ) - SetType(make_intrusive(op1->GetType()->Yield())); - - else - ExprError("Unknown MatchesIndex() return value"); - } - -bool IndexExpr::CanAdd() const - { - if ( IsError() ) - return true; // avoid cascading the error report - - // "add" only allowed if our type is "set". - return op1->GetType()->IsSet(); - } - -bool IndexExpr::CanDel() const - { - if ( IsError() ) - return true; // avoid cascading the error report - - return op1->GetType()->Tag() == TYPE_TABLE; - } - -void IndexExpr::Add(Frame* f) - { - if ( IsError() ) - return; - - auto v1 = op1->Eval(f); - - if ( ! v1 ) - return; - - auto v2 = op2->Eval(f); - - if ( ! v2 ) - return; - - bool iterators_invalidated = false; - v1->AsTableVal()->Assign(std::move(v2), nullptr, true, &iterators_invalidated); - - if ( iterators_invalidated ) - reporter->ExprRuntimeWarning(this, "possible loop/iterator invalidation"); - } - -void IndexExpr::Delete(Frame* f) - { - if ( IsError() ) - return; - - auto v1 = op1->Eval(f); - - if ( ! v1 ) - return; - - auto v2 = op2->Eval(f); - - if ( ! v2 ) - return; - - bool iterators_invalidated = false; - v1->AsTableVal()->Remove(*v2, true, &iterators_invalidated); - - if ( iterators_invalidated ) - reporter->ExprRuntimeWarning(this, "possible loop/iterator invalidation"); - } - -ExprPtr IndexExpr::MakeLvalue() - { - if ( IsString(op1->GetType()->Tag()) ) - ExprError("cannot assign to string index expression"); - - return make_intrusive(ThisPtr()); - } - -ValPtr IndexExpr::Eval(Frame* f) const - { - auto v1 = op1->Eval(f); - - if ( ! v1 ) - return nullptr; - - auto v2 = op2->Eval(f); - - if ( ! v2 ) - return nullptr; - - Val* indv = v2->AsListVal()->Idx(0).get(); - - if ( is_vector(v1) && is_vector(indv) ) - { - VectorVal* v_v1 = v1->AsVectorVal(); - VectorVal* v_v2 = indv->AsVectorVal(); - auto vt = cast_intrusive(GetType()); - - // Booleans select each element (or not). - if ( IsBool(v_v2->GetType()->Yield()->Tag()) ) - { - if ( v_v1->Size() != v_v2->Size() ) - { - RuntimeError("size mismatch, boolean index and vector"); - return nullptr; - } - - return vector_bool_select(vt, v_v1, v_v2); - } - else - // Elements are indices. - return vector_int_select(vt, v_v1, v_v2); - } - else - return Fold(v1.get(), v2.get()); - } - -ValPtr IndexExpr::Fold(Val* v1, Val* v2) const - { - if ( IsError() ) - return nullptr; - - ValPtr v; - - switch ( v1->GetType()->Tag() ) - { - case TYPE_VECTOR: - { - VectorVal* vect = v1->AsVectorVal(); - const ListVal* lv = v2->AsListVal(); - - if ( lv->Length() == 1 ) - { - auto index = lv->Idx(0)->CoerceToInt(); - if ( index < 0 ) - index = vect->Size() + index; - - v = vect->ValAt(index); - } - else - return index_slice(vect, lv); - } - break; - - case TYPE_TABLE: - v = v1->AsTableVal()->FindOrDefault({NewRef{}, v2}); - break; - - case TYPE_STRING: - return index_string(v1->AsString(), v2->AsListVal()); - - default: - RuntimeError("type cannot be indexed"); - break; - } - - if ( v ) - return v; - - RuntimeError("no such index"); - return nullptr; - } - -StringValPtr index_string(const String* s, const ListVal* lv) - { - int len = s->Len(); - String* substring = nullptr; - - if ( lv->Length() == 1 ) - { - zeek_int_t idx = lv->Idx(0)->AsInt(); - - if ( idx < 0 ) - idx += len; - - // Out-of-range index will return null pointer. - substring = s->GetSubstring(idx, 1); - } - else - { - zeek_int_t first = get_slice_index(lv->Idx(0)->AsInt(), len); - zeek_int_t last = get_slice_index(lv->Idx(1)->AsInt(), len); - zeek_int_t substring_len = last - first; - - if ( substring_len < 0 ) - substring = nullptr; - else - substring = s->GetSubstring(first, substring_len); - } - - return make_intrusive(substring ? substring : new String("")); - } - -VectorValPtr index_slice(VectorVal* vect, const ListVal* lv) - { - auto first = lv->Idx(0)->CoerceToInt(); - auto last = lv->Idx(1)->CoerceToInt(); - return index_slice(vect, first, last); - } - -VectorValPtr index_slice(VectorVal* vect, int _first, int _last) - { - size_t len = vect->Size(); - auto result = make_intrusive(vect->GetType()); - - zeek_int_t first = get_slice_index(_first, len); - zeek_int_t last = get_slice_index(_last, len); - zeek_int_t sub_length = last - first; - - if ( sub_length >= 0 ) - { - result->Resize(sub_length); - - for ( zeek_int_t idx = first; idx < last; idx++ ) - result->Assign(idx - first, vect->ValAt(idx)); - } - - return result; - } - -VectorValPtr vector_bool_select(VectorTypePtr vt, const VectorVal* v1, const VectorVal* v2) - { - auto v_result = make_intrusive(std::move(vt)); - - for ( unsigned int i = 0; i < v2->Size(); ++i ) - if ( v2->BoolAt(i) ) - v_result->Assign(v_result->Size() + 1, v1->ValAt(i)); - - return v_result; - } - -VectorValPtr vector_int_select(VectorTypePtr vt, const VectorVal* v1, const VectorVal* v2) - { - auto v_result = make_intrusive(std::move(vt)); - - // The elements are indices. - // - // ### Should handle negative indices here like S does, i.e., - // by excluding those elements. Probably only do this if *all* - // are negative. - v_result->Resize(v2->Size()); - for ( unsigned int i = 0; i < v2->Size(); ++i ) - v_result->Assign(i, v1->ValAt(v2->ValAt(i)->CoerceToInt())); - - return v_result; - } - -void IndexExpr::Assign(Frame* f, ValPtr v) - { - if ( IsError() ) - return; - - auto v1 = op1->Eval(f); - auto v2 = op2->Eval(f); - - AssignToIndex(v1, v2, v); - } - -void IndexExpr::ExprDescribe(ODesc* d) const - { - op1->Describe(d); - if ( d->IsReadable() ) - d->Add("["); - - op2->Describe(d); - if ( d->IsReadable() ) - d->Add("]"); - } - -static void report_field_deprecation(const RecordType* rt, const Expr* e, int field, - bool has_check = false) - { - reporter->Deprecation(util::fmt("%s (%s)", - rt->GetFieldDeprecationWarning(field, has_check).c_str(), - obj_desc_short(e).c_str()), - e->GetLocationInfo()); - } + if ( auto v = op2->Eval(f) ) + op1->Assign(f, std::move(v)); + + return nullptr; +} + +IndexExpr::IndexExpr(ExprPtr arg_op1, ListExprPtr arg_op2, bool arg_is_slice, bool arg_is_inside_when) + : BinaryExpr(EXPR_INDEX, std::move(arg_op1), std::move(arg_op2)), + is_slice(arg_is_slice), + is_inside_when(arg_is_inside_when) { + if ( IsError() ) + return; + + if ( is_slice ) { + if ( ! IsString(op1->GetType()->Tag()) && ! IsVector(op1->GetType()->Tag()) ) + ExprError("slice notation indexing only supported for strings and vectors currently"); + } + + else if ( IsString(op1->GetType()->Tag()) ) { + if ( op2->AsListExpr()->Exprs().length() != 1 ) + ExprError("invalid string index expression"); + } + + if ( IsError() ) + return; + + int match_type = op1->GetType()->MatchesIndex(op2->AsListExpr()); + + if ( match_type == DOES_NOT_MATCH_INDEX ) { + std::string error_msg = + util::fmt("expression with type '%s' is not a type that can be indexed", type_name(op1->GetType()->Tag())); + SetError(error_msg.data()); + } + + else if ( ! op1->GetType()->Yield() ) { + if ( IsString(op1->GetType()->Tag()) && match_type == MATCHES_INDEX_SCALAR ) + SetType(base_type(TYPE_STRING)); + else + // It's a set - so indexing it yields void. We don't + // directly generate an error message, though, since this + // expression might be part of an add/delete statement, + // rather than yielding a value. + SetType(base_type(TYPE_VOID)); + } + + else if ( match_type == MATCHES_INDEX_SCALAR ) + SetType(op1->GetType()->Yield()); + + else if ( match_type == MATCHES_INDEX_VECTOR ) + SetType(make_intrusive(op1->GetType()->Yield())); + + else + ExprError("Unknown MatchesIndex() return value"); +} + +bool IndexExpr::CanAdd() const { + if ( IsError() ) + return true; // avoid cascading the error report + + // "add" only allowed if our type is "set". + return op1->GetType()->IsSet(); +} + +bool IndexExpr::CanDel() const { + if ( IsError() ) + return true; // avoid cascading the error report + + return op1->GetType()->Tag() == TYPE_TABLE; +} + +void IndexExpr::Add(Frame* f) { + if ( IsError() ) + return; + + auto v1 = op1->Eval(f); + + if ( ! v1 ) + return; + + auto v2 = op2->Eval(f); + + if ( ! v2 ) + return; + + bool iterators_invalidated = false; + v1->AsTableVal()->Assign(std::move(v2), nullptr, true, &iterators_invalidated); + + if ( iterators_invalidated ) + reporter->ExprRuntimeWarning(this, "possible loop/iterator invalidation"); +} + +void IndexExpr::Delete(Frame* f) { + if ( IsError() ) + return; + + auto v1 = op1->Eval(f); + + if ( ! v1 ) + return; + + auto v2 = op2->Eval(f); + + if ( ! v2 ) + return; + + bool iterators_invalidated = false; + v1->AsTableVal()->Remove(*v2, true, &iterators_invalidated); + + if ( iterators_invalidated ) + reporter->ExprRuntimeWarning(this, "possible loop/iterator invalidation"); +} + +ExprPtr IndexExpr::MakeLvalue() { + if ( IsString(op1->GetType()->Tag()) ) + ExprError("cannot assign to string index expression"); + + return make_intrusive(ThisPtr()); +} + +ValPtr IndexExpr::Eval(Frame* f) const { + auto v1 = op1->Eval(f); + + if ( ! v1 ) + return nullptr; + + auto v2 = op2->Eval(f); + + if ( ! v2 ) + return nullptr; + + Val* indv = v2->AsListVal()->Idx(0).get(); + + if ( is_vector(v1) && is_vector(indv) ) { + VectorVal* v_v1 = v1->AsVectorVal(); + VectorVal* v_v2 = indv->AsVectorVal(); + auto vt = cast_intrusive(GetType()); + + // Booleans select each element (or not). + if ( IsBool(v_v2->GetType()->Yield()->Tag()) ) { + if ( v_v1->Size() != v_v2->Size() ) { + RuntimeError("size mismatch, boolean index and vector"); + return nullptr; + } + + return vector_bool_select(vt, v_v1, v_v2); + } + else + // Elements are indices. + return vector_int_select(vt, v_v1, v_v2); + } + else + return Fold(v1.get(), v2.get()); +} + +ValPtr IndexExpr::Fold(Val* v1, Val* v2) const { + if ( IsError() ) + return nullptr; + + ValPtr v; + + switch ( v1->GetType()->Tag() ) { + case TYPE_VECTOR: { + VectorVal* vect = v1->AsVectorVal(); + const ListVal* lv = v2->AsListVal(); + + if ( lv->Length() == 1 ) { + auto index = lv->Idx(0)->CoerceToInt(); + if ( index < 0 ) + index = vect->Size() + index; + + v = vect->ValAt(index); + } + else + return index_slice(vect, lv); + } break; + + case TYPE_TABLE: v = v1->AsTableVal()->FindOrDefault({NewRef{}, v2}); break; + + case TYPE_STRING: return index_string(v1->AsString(), v2->AsListVal()); + + default: RuntimeError("type cannot be indexed"); break; + } + + if ( v ) + return v; + + RuntimeError("no such index"); + return nullptr; +} + +StringValPtr index_string(const String* s, const ListVal* lv) { + int len = s->Len(); + String* substring = nullptr; + + if ( lv->Length() == 1 ) { + zeek_int_t idx = lv->Idx(0)->AsInt(); + + if ( idx < 0 ) + idx += len; + + // Out-of-range index will return null pointer. + substring = s->GetSubstring(idx, 1); + } + else { + zeek_int_t first = get_slice_index(lv->Idx(0)->AsInt(), len); + zeek_int_t last = get_slice_index(lv->Idx(1)->AsInt(), len); + zeek_int_t substring_len = last - first; + + if ( substring_len < 0 ) + substring = nullptr; + else + substring = s->GetSubstring(first, substring_len); + } + + return make_intrusive(substring ? substring : new String("")); +} + +VectorValPtr index_slice(VectorVal* vect, const ListVal* lv) { + auto first = lv->Idx(0)->CoerceToInt(); + auto last = lv->Idx(1)->CoerceToInt(); + return index_slice(vect, first, last); +} + +VectorValPtr index_slice(VectorVal* vect, int _first, int _last) { + size_t len = vect->Size(); + auto result = make_intrusive(vect->GetType()); + + zeek_int_t first = get_slice_index(_first, len); + zeek_int_t last = get_slice_index(_last, len); + zeek_int_t sub_length = last - first; + + if ( sub_length >= 0 ) { + result->Resize(sub_length); + + for ( zeek_int_t idx = first; idx < last; idx++ ) + result->Assign(idx - first, vect->ValAt(idx)); + } + + return result; +} + +VectorValPtr vector_bool_select(VectorTypePtr vt, const VectorVal* v1, const VectorVal* v2) { + auto v_result = make_intrusive(std::move(vt)); + + for ( unsigned int i = 0; i < v2->Size(); ++i ) + if ( v2->BoolAt(i) ) + v_result->Assign(v_result->Size() + 1, v1->ValAt(i)); + + return v_result; +} + +VectorValPtr vector_int_select(VectorTypePtr vt, const VectorVal* v1, const VectorVal* v2) { + auto v_result = make_intrusive(std::move(vt)); + + // The elements are indices. + // + // ### Should handle negative indices here like S does, i.e., + // by excluding those elements. Probably only do this if *all* + // are negative. + v_result->Resize(v2->Size()); + for ( unsigned int i = 0; i < v2->Size(); ++i ) + v_result->Assign(i, v1->ValAt(v2->ValAt(i)->CoerceToInt())); + + return v_result; +} + +void IndexExpr::Assign(Frame* f, ValPtr v) { + if ( IsError() ) + return; + + auto v1 = op1->Eval(f); + auto v2 = op2->Eval(f); + + AssignToIndex(v1, v2, v); +} + +void IndexExpr::ExprDescribe(ODesc* d) const { + op1->Describe(d); + if ( d->IsReadable() ) + d->Add("["); + + op2->Describe(d); + if ( d->IsReadable() ) + d->Add("]"); +} + +static void report_field_deprecation(const RecordType* rt, const Expr* e, int field, bool has_check = false) { + reporter->Deprecation(util::fmt("%s (%s)", rt->GetFieldDeprecationWarning(field, has_check).c_str(), + obj_desc_short(e).c_str()), + e->GetLocationInfo()); +} FieldExpr::FieldExpr(ExprPtr arg_op, const char* arg_field_name) - : UnaryExpr(EXPR_FIELD, std::move(arg_op)), field_name(util::copy_string(arg_field_name)), - td(nullptr), field(0) - { - if ( IsError() ) - return; + : UnaryExpr(EXPR_FIELD, std::move(arg_op)), field_name(util::copy_string(arg_field_name)), td(nullptr), field(0) { + if ( IsError() ) + return; - if ( ! IsRecord(op->GetType()->Tag()) ) - ExprError("not a record"); - else - { - RecordType* rt = op->GetType()->AsRecordType(); - field = rt->FieldOffset(field_name); + if ( ! IsRecord(op->GetType()->Tag()) ) + ExprError("not a record"); + else { + RecordType* rt = op->GetType()->AsRecordType(); + field = rt->FieldOffset(field_name); - if ( field < 0 ) - ExprError("no such field in record"); - else - { - SetType(rt->GetFieldType(field)); - td = rt->FieldDecl(field); + if ( field < 0 ) + ExprError("no such field in record"); + else { + SetType(rt->GetFieldType(field)); + td = rt->FieldDecl(field); - if ( rt->IsFieldDeprecated(field) ) - report_field_deprecation(rt, this, field); - } - } - } + if ( rt->IsFieldDeprecated(field) ) + report_field_deprecation(rt, this, field); + } + } +} -FieldExpr::~FieldExpr() - { - delete[] field_name; - } +FieldExpr::~FieldExpr() { delete[] field_name; } -ExprPtr FieldExpr::MakeLvalue() - { - return make_intrusive(ThisPtr()); - } +ExprPtr FieldExpr::MakeLvalue() { return make_intrusive(ThisPtr()); } -bool FieldExpr::CanDel() const - { - return td->GetAttr(ATTR_DEFAULT) || td->GetAttr(ATTR_OPTIONAL); - } +bool FieldExpr::CanDel() const { return td->GetAttr(ATTR_DEFAULT) || td->GetAttr(ATTR_OPTIONAL); } -void FieldExpr::Assign(Frame* f, ValPtr v) - { - if ( IsError() ) - return; +void FieldExpr::Assign(Frame* f, ValPtr v) { + if ( IsError() ) + return; - if ( auto op_v = op->Eval(f) ) - { - RecordVal* r = op_v->AsRecordVal(); - r->Assign(field, std::move(v)); - } - } + if ( auto op_v = op->Eval(f) ) { + RecordVal* r = op_v->AsRecordVal(); + r->Assign(field, std::move(v)); + } +} -void FieldExpr::Delete(Frame* f) - { - Assign(f, nullptr); - } +void FieldExpr::Delete(Frame* f) { Assign(f, nullptr); } -ValPtr FieldExpr::Fold(Val* v) const - { - if ( const auto& result = v->AsRecordVal()->GetField(field) ) - return result; +ValPtr FieldExpr::Fold(Val* v) const { + if ( const auto& result = v->AsRecordVal()->GetField(field) ) + return result; - // Check for &default. - const Attr* def_attr = td ? td->GetAttr(ATTR_DEFAULT).get() : nullptr; + // Check for &default. + const Attr* def_attr = td ? td->GetAttr(ATTR_DEFAULT).get() : nullptr; - if ( def_attr ) - return def_attr->GetExpr()->Eval(nullptr); - else - { - RuntimeError("field value missing"); - assert(false); - return nullptr; // Will never get here, but compiler can't tell. - } - } + if ( def_attr ) + return def_attr->GetExpr()->Eval(nullptr); + else { + RuntimeError("field value missing"); + assert(false); + return nullptr; // Will never get here, but compiler can't tell. + } +} -void FieldExpr::ExprDescribe(ODesc* d) const - { - op->Describe(d); - if ( d->IsReadable() ) - d->Add("$"); +void FieldExpr::ExprDescribe(ODesc* d) const { + op->Describe(d); + if ( d->IsReadable() ) + d->Add("$"); - if ( IsError() ) - d->Add(""); - else if ( d->IsReadable() ) - d->Add(field_name); - else - d->Add(field); - } + if ( IsError() ) + d->Add(""); + else if ( d->IsReadable() ) + d->Add(field_name); + else + d->Add(field); +} HasFieldExpr::HasFieldExpr(ExprPtr arg_op, const char* arg_field_name) - : UnaryExpr(EXPR_HAS_FIELD, std::move(arg_op)), field_name(arg_field_name), field(0) - { - if ( IsError() ) - return; + : UnaryExpr(EXPR_HAS_FIELD, std::move(arg_op)), field_name(arg_field_name), field(0) { + if ( IsError() ) + return; - if ( ! IsRecord(op->GetType()->Tag()) ) - ExprError("not a record"); - else - { - RecordType* rt = op->GetType()->AsRecordType(); - field = rt->FieldOffset(field_name); + if ( ! IsRecord(op->GetType()->Tag()) ) + ExprError("not a record"); + else { + RecordType* rt = op->GetType()->AsRecordType(); + field = rt->FieldOffset(field_name); - if ( field < 0 ) - ExprError("no such field in record"); - else if ( rt->IsFieldDeprecated(field) ) - report_field_deprecation(rt, this, field, true); + if ( field < 0 ) + ExprError("no such field in record"); + else if ( rt->IsFieldDeprecated(field) ) + report_field_deprecation(rt, this, field, true); - SetType(base_type(TYPE_BOOL)); - } - } + SetType(base_type(TYPE_BOOL)); + } +} -HasFieldExpr::~HasFieldExpr() - { - delete field_name; - } +HasFieldExpr::~HasFieldExpr() { delete field_name; } -ValPtr HasFieldExpr::Fold(Val* v) const - { - auto rv = v->AsRecordVal(); - return val_mgr->Bool(rv->HasField(field)); - } +ValPtr HasFieldExpr::Fold(Val* v) const { + auto rv = v->AsRecordVal(); + return val_mgr->Bool(rv->HasField(field)); +} -void HasFieldExpr::ExprDescribe(ODesc* d) const - { - op->Describe(d); +void HasFieldExpr::ExprDescribe(ODesc* d) const { + op->Describe(d); - if ( d->IsReadable() ) - d->Add("?$"); + if ( d->IsReadable() ) + d->Add("?$"); - if ( IsError() ) - d->Add(""); - else if ( d->IsReadable() ) - d->Add(field_name); - else - d->Add(field); - } + if ( IsError() ) + d->Add(""); + else if ( d->IsReadable() ) + d->Add(field_name); + else + d->Add(field); +} RecordConstructorExpr::RecordConstructorExpr(ListExprPtr constructor_list) - : Expr(EXPR_RECORD_CONSTRUCTOR), op(std::move(constructor_list)), map(std::nullopt) - { - if ( IsError() ) - return; + : Expr(EXPR_RECORD_CONSTRUCTOR), op(std::move(constructor_list)), map(std::nullopt) { + if ( IsError() ) + return; - // Spin through the list, which should be comprised only of - // record-field-assign expressions, and build up a - // record type to associate with this constructor. - const ExprPList& exprs = op->AsListExpr()->Exprs(); - type_decl_list* record_types = new type_decl_list(exprs.length()); + // Spin through the list, which should be comprised only of + // record-field-assign expressions, and build up a + // record type to associate with this constructor. + const ExprPList& exprs = op->AsListExpr()->Exprs(); + type_decl_list* record_types = new type_decl_list(exprs.length()); - const Expr* constructor_error_expr = nullptr; + const Expr* constructor_error_expr = nullptr; - for ( const auto& e : exprs ) - { - if ( e->Tag() != EXPR_FIELD_ASSIGN ) - { - // Don't generate the error yet, as reporting it - // requires that we have a well-formed type. - constructor_error_expr = e; - SetError(); - continue; - } + for ( const auto& e : exprs ) { + if ( e->Tag() != EXPR_FIELD_ASSIGN ) { + // Don't generate the error yet, as reporting it + // requires that we have a well-formed type. + constructor_error_expr = e; + SetError(); + continue; + } - FieldAssignExpr* field = (FieldAssignExpr*)e; - const auto& field_type = field->GetType(); - char* field_name = util::copy_string(field->FieldName()); - record_types->push_back(new TypeDecl(field_name, field_type)); - } + FieldAssignExpr* field = (FieldAssignExpr*)e; + const auto& field_type = field->GetType(); + char* field_name = util::copy_string(field->FieldName()); + record_types->push_back(new TypeDecl(field_name, field_type)); + } - SetType(make_intrusive(record_types)); + SetType(make_intrusive(record_types)); - if ( constructor_error_expr ) - Error("bad type in record constructor", constructor_error_expr); - } + if ( constructor_error_expr ) + Error("bad type in record constructor", constructor_error_expr); +} RecordConstructorExpr::RecordConstructorExpr(RecordTypePtr known_rt, ListExprPtr constructor_list) - : Expr(EXPR_RECORD_CONSTRUCTOR), op(std::move(constructor_list)) - { - if ( IsError() ) - return; + : Expr(EXPR_RECORD_CONSTRUCTOR), op(std::move(constructor_list)) { + if ( IsError() ) + return; - SetType(known_rt); + SetType(known_rt); - const auto& exprs = op->AsListExpr()->Exprs(); - map = std::vector(exprs.length()); + const auto& exprs = op->AsListExpr()->Exprs(); + map = std::vector(exprs.length()); - std::set fields_seen; // used to check for missing fields + std::set fields_seen; // used to check for missing fields - int i = 0; - for ( const auto& e : exprs ) - { - if ( e->Tag() != EXPR_FIELD_ASSIGN ) - { - Error("bad type in record constructor", e); - SetError(); - continue; - } + int i = 0; + for ( const auto& e : exprs ) { + if ( e->Tag() != EXPR_FIELD_ASSIGN ) { + Error("bad type in record constructor", e); + SetError(); + continue; + } - auto field = e->AsFieldAssignExpr(); - int index = known_rt->FieldOffset(field->FieldName()); + auto field = e->AsFieldAssignExpr(); + int index = known_rt->FieldOffset(field->FieldName()); - if ( index < 0 ) - { - Error("no such field in record", e); - SetError(); - continue; - } + if ( index < 0 ) { + Error("no such field in record", e); + SetError(); + continue; + } - auto known_ft = known_rt->GetFieldType(index); + auto known_ft = known_rt->GetFieldType(index); - if ( ! field->PromoteTo(known_ft) ) - SetError(); + if ( ! field->PromoteTo(known_ft) ) + SetError(); - (*map)[i++] = index; - fields_seen.insert(index); - } + (*map)[i++] = index; + fields_seen.insert(index); + } - if ( IsError() ) - return; + if ( IsError() ) + return; - auto n = known_rt->NumFields(); - for ( i = 0; i < n; ++i ) - if ( fields_seen.count(i) == 0 ) - { - const auto td_i = known_rt->FieldDecl(i); - if ( IsAggr(td_i->type) ) - // These are always initialized. - continue; + auto n = known_rt->NumFields(); + for ( i = 0; i < n; ++i ) + if ( fields_seen.count(i) == 0 ) { + const auto td_i = known_rt->FieldDecl(i); + if ( IsAggr(td_i->type) ) + // These are always initialized. + continue; - if ( ! td_i->GetAttr(ATTR_OPTIONAL) ) - { - auto err = std::string("mandatory field \"") + known_rt->FieldName(i) + - "\" missing"; - ExprError(err.c_str()); - } - } - else if ( known_rt->IsFieldDeprecated(i) ) - report_field_deprecation(known_rt.get(), this, i); - } + if ( ! td_i->GetAttr(ATTR_OPTIONAL) ) { + auto err = std::string("mandatory field \"") + known_rt->FieldName(i) + "\" missing"; + ExprError(err.c_str()); + } + } + else if ( known_rt->IsFieldDeprecated(i) ) + report_field_deprecation(known_rt.get(), this, i); +} -ValPtr RecordConstructorExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; +ValPtr RecordConstructorExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; - const auto& exprs = op->Exprs(); - auto rt = cast_intrusive(type); + const auto& exprs = op->Exprs(); + auto rt = cast_intrusive(type); - if ( ! map && exprs.length() != rt->NumFields() ) - RuntimeErrorWithCallStack("inconsistency evaluating record constructor"); + if ( ! map && exprs.length() != rt->NumFields() ) + RuntimeErrorWithCallStack("inconsistency evaluating record constructor"); - auto rv = make_intrusive(rt); + auto rv = make_intrusive(rt); - for ( int i = 0; i < exprs.length(); ++i ) - { - auto v_i = exprs[i]->Eval(f); - int ind = map ? (*map)[i] : i; + for ( int i = 0; i < exprs.length(); ++i ) { + auto v_i = exprs[i]->Eval(f); + int ind = map ? (*map)[i] : i; - if ( v_i && v_i->GetType()->Tag() == TYPE_VECTOR && - v_i->GetType()->IsUnspecifiedVector() ) - { - const auto& t_ind = rt->GetFieldType(ind); - v_i->AsVectorVal()->Concretize(t_ind->Yield()); - } + if ( v_i && v_i->GetType()->Tag() == TYPE_VECTOR && v_i->GetType()->IsUnspecifiedVector() ) { + const auto& t_ind = rt->GetFieldType(ind); + v_i->AsVectorVal()->Concretize(t_ind->Yield()); + } - rv->Assign(ind, v_i); - } + rv->Assign(ind, v_i); + } - return rv; - } + return rv; +} -bool RecordConstructorExpr::IsPure() const - { - return op->IsPure(); - } +bool RecordConstructorExpr::IsPure() const { return op->IsPure(); } -void RecordConstructorExpr::ExprDescribe(ODesc* d) const - { - auto& tn = type->GetName(); +void RecordConstructorExpr::ExprDescribe(ODesc* d) const { + auto& tn = type->GetName(); - if ( tn.size() > 0 ) - { - d->Add(tn); - d->Add("("); - op->Describe(d); - d->Add(")"); - } - else - { - d->Add("["); - op->Describe(d); - d->Add("]"); - } - } + if ( tn.size() > 0 ) { + d->Add(tn); + d->Add("("); + op->Describe(d); + d->Add(")"); + } + else { + d->Add("["); + op->Describe(d); + d->Add("]"); + } +} -TraversalCode RecordConstructorExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); +TraversalCode RecordConstructorExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); - tc = op->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); + tc = op->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} -static ExprPtr expand_one_elem(const ExprPList& index_exprs, ExprPtr yield, ExprPtr elem, - int elem_offset) - { - auto expanded_elem = make_intrusive(); +static ExprPtr expand_one_elem(const ExprPList& index_exprs, ExprPtr yield, ExprPtr elem, int elem_offset) { + auto expanded_elem = make_intrusive(); - for ( int i = 0; i < index_exprs.length(); ++i ) - if ( i == elem_offset ) - expanded_elem->Append(elem); - else - expanded_elem->Append({NewRef{}, index_exprs[i]}); + for ( int i = 0; i < index_exprs.length(); ++i ) + if ( i == elem_offset ) + expanded_elem->Append(elem); + else + expanded_elem->Append({NewRef{}, index_exprs[i]}); - if ( yield ) - return make_intrusive(expanded_elem, yield, true); - else - return expanded_elem; - } + if ( yield ) + return make_intrusive(expanded_elem, yield, true); + else + return expanded_elem; +} -static bool expand_op_elem(ListExprPtr elems, ExprPtr elem, TypePtr t) - { - ExprPtr index; - ExprPtr yield; +static bool expand_op_elem(ListExprPtr elems, ExprPtr elem, TypePtr t) { + ExprPtr index; + ExprPtr yield; - if ( elem->Tag() == EXPR_ASSIGN ) - { - if ( t ) - { - if ( ! t->IsTable() ) - { - elem->Error("table constructor used in a non-table context"); - return false; - } + if ( elem->Tag() == EXPR_ASSIGN ) { + if ( t ) { + if ( ! t->IsTable() ) { + elem->Error("table constructor used in a non-table context"); + return false; + } - t = t->AsTableType()->GetIndices(); - } + t = t->AsTableType()->GetIndices(); + } - index = elem->GetOp1(); - yield = elem->GetOp2(); - } - else - index = elem; // this is a set - no yield + index = elem->GetOp1(); + yield = elem->GetOp2(); + } + else + index = elem; // this is a set - no yield - // If the index isn't a list, then there's nothing to consider - // expanding. - if ( index->Tag() != EXPR_LIST ) - { - elems->Append(elem); - return false; - } + // If the index isn't a list, then there's nothing to consider + // expanding. + if ( index->Tag() != EXPR_LIST ) { + elems->Append(elem); + return false; + } - // Look inside the index for any sub-lists or sets, and expand those. - // There might be more than one, but we'll pick that up recursively - // later. - auto& index_exprs = index->AsListExpr()->Exprs(); - int index_n = index_exprs.length(); - int list_offset = -1; - int set_offset = -1; - for ( int i = 0; i < index_n; ++i ) - { - auto& ie_i = index_exprs[i]; + // Look inside the index for any sub-lists or sets, and expand those. + // There might be more than one, but we'll pick that up recursively + // later. + auto& index_exprs = index->AsListExpr()->Exprs(); + int index_n = index_exprs.length(); + int list_offset = -1; + int set_offset = -1; + for ( int i = 0; i < index_n; ++i ) { + auto& ie_i = index_exprs[i]; - if ( ie_i->Tag() == EXPR_LIST ) - { - list_offset = i; - break; - } + if ( ie_i->Tag() == EXPR_LIST ) { + list_offset = i; + break; + } - if ( ie_i->GetType()->IsSet() ) - { - // Check for this set corresponding to what's expected - // in this location, in which case it shouldn't be - // expanded. - const TypeList* tl = nullptr; - if ( t && t->Tag() == TYPE_LIST ) - tl = t->AsTypeList(); + if ( ie_i->GetType()->IsSet() ) { + // Check for this set corresponding to what's expected + // in this location, in which case it shouldn't be + // expanded. + const TypeList* tl = nullptr; + if ( t && t->Tag() == TYPE_LIST ) + tl = t->AsTypeList(); - // So we're good-to-go in expanding if either - // (1) we weren't given a type, or it's not a list, - // or (2) it's a list, but doesn't correspond in - // length to the list of expressions, or (3) it does - // but its corresponding element at this position - // doesn't have the same type as this set. - if ( ! tl || static_cast(tl->GetTypes().size()) != index_n || - ! same_type(tl->GetTypes()[i], ie_i->GetType()) ) - { - set_offset = i; - break; - } - } - } + // So we're good-to-go in expanding if either + // (1) we weren't given a type, or it's not a list, + // or (2) it's a list, but doesn't correspond in + // length to the list of expressions, or (3) it does + // but its corresponding element at this position + // doesn't have the same type as this set. + if ( ! tl || static_cast(tl->GetTypes().size()) != index_n || + ! same_type(tl->GetTypes()[i], ie_i->GetType()) ) { + set_offset = i; + break; + } + } + } - if ( set_offset >= 0 ) - { // expand the set - auto s_e = index_exprs[set_offset]; - auto v = s_e->Eval(nullptr); - if ( ! v ) - { - s_e->Error( - "cannot expand constructor elements using a value that depends on local variables"); - elems->SetError(); - return false; - } + if ( set_offset >= 0 ) { // expand the set + auto s_e = index_exprs[set_offset]; + auto v = s_e->Eval(nullptr); + if ( ! v ) { + s_e->Error("cannot expand constructor elements using a value that depends on local variables"); + elems->SetError(); + return false; + } - for ( auto& s_elem : v->AsTableVal()->ToMap() ) - { - auto c_elem = make_intrusive(s_elem.first); - elems->Append(expand_one_elem(index_exprs, yield, c_elem, set_offset)); - } + for ( auto& s_elem : v->AsTableVal()->ToMap() ) { + auto c_elem = make_intrusive(s_elem.first); + elems->Append(expand_one_elem(index_exprs, yield, c_elem, set_offset)); + } - return true; - } + return true; + } - if ( list_offset < 0 ) - { // No embedded lists. - elems->Append(elem); - return false; - } + if ( list_offset < 0 ) { // No embedded lists. + elems->Append(elem); + return false; + } - // Expand the identified list. - auto sub_list = index_exprs[list_offset]->AsListExpr(); - for ( auto& sub_list_i : sub_list->Exprs() ) - { - ExprPtr e = {NewRef{}, sub_list_i}; - elems->Append(expand_one_elem(index_exprs, yield, e, list_offset)); - } + // Expand the identified list. + auto sub_list = index_exprs[list_offset]->AsListExpr(); + for ( auto& sub_list_i : sub_list->Exprs() ) { + ExprPtr e = {NewRef{}, sub_list_i}; + elems->Append(expand_one_elem(index_exprs, yield, e, list_offset)); + } - return true; - } + return true; +} -ListExprPtr expand_op(ListExprPtr op, const TypePtr& t) - { - auto new_list = make_intrusive(); - bool did_expansion = false; +ListExprPtr expand_op(ListExprPtr op, const TypePtr& t) { + auto new_list = make_intrusive(); + bool did_expansion = false; - for ( auto e : op->Exprs() ) - { - if ( expand_op_elem(new_list, {NewRef{}, e}, t) ) - did_expansion = true; + for ( auto e : op->Exprs() ) { + if ( expand_op_elem(new_list, {NewRef{}, e}, t) ) + did_expansion = true; - if ( new_list->IsError() ) - { - op->SetError(); - return op; - } - } + if ( new_list->IsError() ) { + op->SetError(); + return op; + } + } - if ( did_expansion ) - return expand_op(new_list, t); - else - return op; - } + if ( did_expansion ) + return expand_op(new_list, t); + else + return op; +} TableConstructorExpr::TableConstructorExpr(ListExprPtr constructor_list, - std::unique_ptr> arg_attrs, - TypePtr arg_type, AttributesPtr arg_attrs2) - : UnaryExpr(EXPR_TABLE_CONSTRUCTOR, expand_op(constructor_list, arg_type)) - { - if ( IsError() ) - return; + std::unique_ptr> arg_attrs, TypePtr arg_type, + AttributesPtr arg_attrs2) + : UnaryExpr(EXPR_TABLE_CONSTRUCTOR, expand_op(constructor_list, arg_type)) { + if ( IsError() ) + return; - if ( arg_type ) - { - if ( ! arg_type->IsTable() ) - { - Error("bad table constructor type", arg_type.get()); - SetError(); - return; - } + if ( arg_type ) { + if ( ! arg_type->IsTable() ) { + Error("bad table constructor type", arg_type.get()); + SetError(); + return; + } - SetType(std::move(arg_type)); - } - else - { - if ( op->AsListExpr()->Exprs().empty() ) - SetType( - make_intrusive(make_intrusive(base_type(TYPE_ANY)), nullptr)); - else - { - SetType(init_type(op)); + SetType(std::move(arg_type)); + } + else { + if ( op->AsListExpr()->Exprs().empty() ) + SetType(make_intrusive(make_intrusive(base_type(TYPE_ANY)), nullptr)); + else { + SetType(init_type(op)); - if ( ! type ) - { - SetError(); - return; - } + if ( ! type ) { + SetError(); + return; + } - else if ( type->Tag() != TYPE_TABLE || type->AsTableType()->IsSet() ) - { - SetError("values in table(...) constructor do not specify a table"); - return; - } - } - } + else if ( type->Tag() != TYPE_TABLE || type->AsTableType()->IsSet() ) { + SetError("values in table(...) constructor do not specify a table"); + return; + } + } + } - if ( arg_attrs ) - SetAttrs(make_intrusive(std::move(*arg_attrs), type, false, false)); - else - SetAttrs(arg_attrs2); + if ( arg_attrs ) + SetAttrs(make_intrusive(std::move(*arg_attrs), type, false, false)); + else + SetAttrs(arg_attrs2); - const auto& indices = type->AsTableType()->GetIndices()->GetTypes(); - const ExprPList& cle = op->AsListExpr()->Exprs(); + const auto& indices = type->AsTableType()->GetIndices()->GetTypes(); + const ExprPList& cle = op->AsListExpr()->Exprs(); - // check and promote all assign expressions in ctor list - for ( const auto& expr : cle ) - { - if ( expr->Tag() != EXPR_ASSIGN ) - { - expr->Error("illegal table constructor element"); - SetError(); - return; - } + // check and promote all assign expressions in ctor list + for ( const auto& expr : cle ) { + if ( expr->Tag() != EXPR_ASSIGN ) { + expr->Error("illegal table constructor element"); + SetError(); + return; + } - auto idx_expr = expr->AsAssignExpr()->GetOp1(); - auto val_expr = expr->AsAssignExpr()->GetOp2(); - auto yield_type = GetType()->AsTableType()->Yield(); + auto idx_expr = expr->AsAssignExpr()->GetOp1(); + auto val_expr = expr->AsAssignExpr()->GetOp2(); + auto yield_type = GetType()->AsTableType()->Yield(); - if ( idx_expr->Tag() != EXPR_LIST ) - { - expr->Error("table constructor index is not a list"); - SetError(); - return; - } + if ( idx_expr->Tag() != EXPR_LIST ) { + expr->Error("table constructor index is not a list"); + SetError(); + return; + } - // Promote LHS - ExprPList& idx_exprs = idx_expr->AsListExpr()->Exprs(); + // Promote LHS + ExprPList& idx_exprs = idx_expr->AsListExpr()->Exprs(); - if ( idx_exprs.length() != static_cast(indices.size()) ) - continue; + if ( idx_exprs.length() != static_cast(indices.size()) ) + continue; - loop_over_list(idx_exprs, j) - { - ExprPtr idx = {NewRef{}, idx_exprs[j]}; + loop_over_list(idx_exprs, j) { + ExprPtr idx = {NewRef{}, idx_exprs[j]}; - auto promoted_idx = check_and_promote_expr(idx, indices[j]); + auto promoted_idx = check_and_promote_expr(idx, indices[j]); - if ( promoted_idx ) - { - if ( promoted_idx != idx ) - Unref(idx_exprs.replace(j, promoted_idx.release())); + if ( promoted_idx ) { + if ( promoted_idx != idx ) + Unref(idx_exprs.replace(j, promoted_idx.release())); - continue; - } + continue; + } - ExprError("inconsistent types in table constructor"); - return; - } + ExprError("inconsistent types in table constructor"); + return; + } - // Promote RHS - if ( auto promoted_val = check_and_promote_expr(val_expr, yield_type) ) - { - if ( promoted_val != val_expr ) - expr->AsAssignExpr()->SetOp2(promoted_val); - } - else - { - ExprError("inconsistent types in table constructor"); - return; - } - } - } + // Promote RHS + if ( auto promoted_val = check_and_promote_expr(val_expr, yield_type) ) { + if ( promoted_val != val_expr ) + expr->AsAssignExpr()->SetOp2(promoted_val); + } + else { + ExprError("inconsistent types in table constructor"); + return; + } + } +} -TraversalCode TableConstructorExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); +TraversalCode TableConstructorExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); - tc = op->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); + tc = op->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); - if ( attrs ) - { - tc = attrs->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - } + if ( attrs ) { + tc = attrs->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + } - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} -ValPtr TableConstructorExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; +ValPtr TableConstructorExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; - auto tv = make_intrusive(GetType(), attrs); - const ExprPList& exprs = op->AsListExpr()->Exprs(); + auto tv = make_intrusive(GetType(), attrs); + const ExprPList& exprs = op->AsListExpr()->Exprs(); - for ( const auto& expr : exprs ) - { - auto op1 = expr->GetOp1(); - auto op2 = expr->GetOp2(); + for ( const auto& expr : exprs ) { + auto op1 = expr->GetOp1(); + auto op2 = expr->GetOp2(); - if ( ! op1 || ! op2 ) - return nullptr; + if ( ! op1 || ! op2 ) + return nullptr; - auto index = op1->Eval(f); - auto v = op2->Eval(f); + auto index = op1->Eval(f); + auto v = op2->Eval(f); - if ( ! index || ! v ) - return nullptr; + if ( ! index || ! v ) + return nullptr; - if ( ! tv->Assign(std::move(index), std::move(v)) ) - RuntimeError("type clash in table assignment"); - } + if ( ! tv->Assign(std::move(index), std::move(v)) ) + RuntimeError("type clash in table assignment"); + } - tv->InitDefaultFunc(f); + tv->InitDefaultFunc(f); - return tv; - } + return tv; +} -void TableConstructorExpr::ExprDescribe(ODesc* d) const - { - d->Add("table("); - op->Describe(d); - d->Add(")"); +void TableConstructorExpr::ExprDescribe(ODesc* d) const { + d->Add("table("); + op->Describe(d); + d->Add(")"); - if ( attrs ) - attrs->Describe(d); - } + if ( attrs ) + attrs->Describe(d); +} -SetConstructorExpr::SetConstructorExpr(ListExprPtr constructor_list, - std::unique_ptr> arg_attrs, +SetConstructorExpr::SetConstructorExpr(ListExprPtr constructor_list, std::unique_ptr> arg_attrs, TypePtr arg_type, AttributesPtr arg_attrs2) - : UnaryExpr(EXPR_SET_CONSTRUCTOR, expand_op(std::move(constructor_list), arg_type)) - { - if ( IsError() ) - return; + : UnaryExpr(EXPR_SET_CONSTRUCTOR, expand_op(std::move(constructor_list), arg_type)) { + if ( IsError() ) + return; - if ( arg_type ) - { - if ( ! arg_type->IsSet() ) - { - Error("bad set constructor type", arg_type.get()); - SetError(); - return; - } + if ( arg_type ) { + if ( ! arg_type->IsSet() ) { + Error("bad set constructor type", arg_type.get()); + SetError(); + return; + } - SetType(std::move(arg_type)); - } - else - { - if ( op->AsListExpr()->Exprs().empty() ) - SetType(make_intrusive(make_intrusive(base_type(TYPE_ANY)), - nullptr)); - else - SetType(init_type(op)); - } + SetType(std::move(arg_type)); + } + else { + if ( op->AsListExpr()->Exprs().empty() ) + SetType(make_intrusive(make_intrusive(base_type(TYPE_ANY)), nullptr)); + else + SetType(init_type(op)); + } - if ( ! type ) - SetError(); + if ( ! type ) + SetError(); - else if ( type->Tag() != TYPE_TABLE || ! type->AsTableType()->IsSet() ) - SetError("values in set(...) constructor do not specify a set"); + else if ( type->Tag() != TYPE_TABLE || ! type->AsTableType()->IsSet() ) + SetError("values in set(...) constructor do not specify a set"); - if ( arg_attrs ) - SetAttrs(make_intrusive(std::move(*arg_attrs), type, false, false)); - else - SetAttrs(std::move(arg_attrs2)); + if ( arg_attrs ) + SetAttrs(make_intrusive(std::move(*arg_attrs), type, false, false)); + else + SetAttrs(std::move(arg_attrs2)); - const auto& indices = type->AsTableType()->GetIndices()->GetTypes(); - ExprPList& cle = op->AsListExpr()->Exprs(); + const auto& indices = type->AsTableType()->GetIndices()->GetTypes(); + ExprPList& cle = op->AsListExpr()->Exprs(); - if ( indices.size() == 1 ) - { - if ( ! check_and_promote_exprs_to_type(op->AsListExpr(), indices[0]) ) - ExprError("inconsistent type in set constructor"); - } + if ( indices.size() == 1 ) { + if ( ! check_and_promote_exprs_to_type(op->AsListExpr(), indices[0]) ) + ExprError("inconsistent type in set constructor"); + } - else if ( indices.size() > 1 ) - { - // Check/promote each expression in composite index. - loop_over_list(cle, i) - { - Expr* ce = cle[i]; + else if ( indices.size() > 1 ) { + // Check/promote each expression in composite index. + loop_over_list(cle, i) { + Expr* ce = cle[i]; - if ( ce->Tag() != EXPR_LIST ) - { - ce->Error("not a list of indices"); - SetError(); - return; - } + if ( ce->Tag() != EXPR_LIST ) { + ce->Error("not a list of indices"); + SetError(); + return; + } - ListExpr* le = ce->AsListExpr(); + ListExpr* le = ce->AsListExpr(); - if ( check_and_promote_exprs(le, type->AsTableType()->GetIndices()) ) - { - if ( le != cle[i] ) - cle.replace(i, le); + if ( check_and_promote_exprs(le, type->AsTableType()->GetIndices()) ) { + if ( le != cle[i] ) + cle.replace(i, le); - continue; - } + continue; + } - ExprError("inconsistent types in set constructor"); - } - } - } + ExprError("inconsistent types in set constructor"); + } + } +} -TraversalCode SetConstructorExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); +TraversalCode SetConstructorExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); - tc = op->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); + tc = op->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); - if ( attrs ) - { - tc = attrs->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - } + if ( attrs ) { + tc = attrs->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + } - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} -ValPtr SetConstructorExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; +ValPtr SetConstructorExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; - auto aggr = make_intrusive(IntrusivePtr{NewRef{}, type->AsTableType()}, attrs); - const ExprPList& exprs = op->AsListExpr()->Exprs(); + auto aggr = make_intrusive(IntrusivePtr{NewRef{}, type->AsTableType()}, attrs); + const ExprPList& exprs = op->AsListExpr()->Exprs(); - for ( const auto& expr : exprs ) - { - auto element = expr->Eval(f); - aggr->Assign(std::move(element), nullptr); - } + for ( const auto& expr : exprs ) { + auto element = expr->Eval(f); + aggr->Assign(std::move(element), nullptr); + } - return aggr; - } + return aggr; +} -void SetConstructorExpr::ExprDescribe(ODesc* d) const - { - d->Add("set("); - op->Describe(d); - d->Add(")"); +void SetConstructorExpr::ExprDescribe(ODesc* d) const { + d->Add("set("); + op->Describe(d); + d->Add(")"); - if ( attrs ) - attrs->Describe(d); - } + if ( attrs ) + attrs->Describe(d); +} VectorConstructorExpr::VectorConstructorExpr(ListExprPtr constructor_list, TypePtr arg_type) - : UnaryExpr(EXPR_VECTOR_CONSTRUCTOR, std::move(constructor_list)) - { - if ( IsError() ) - return; + : UnaryExpr(EXPR_VECTOR_CONSTRUCTOR, std::move(constructor_list)) { + if ( IsError() ) + return; - if ( arg_type ) - { - if ( arg_type->Tag() != TYPE_VECTOR ) - { - Error("bad vector constructor type", arg_type.get()); - SetError(); - return; - } + if ( arg_type ) { + if ( arg_type->Tag() != TYPE_VECTOR ) { + Error("bad vector constructor type", arg_type.get()); + SetError(); + return; + } - SetType(std::move(arg_type)); - } - else - { - if ( op->AsListExpr()->Exprs().empty() ) - { - // vector(). - // By default, assign VOID type here. A vector with - // void type set is seen as an unspecified vector. - SetType(make_intrusive(base_type(TYPE_VOID))); - return; - } + SetType(std::move(arg_type)); + } + else { + if ( op->AsListExpr()->Exprs().empty() ) { + // vector(). + // By default, assign VOID type here. A vector with + // void type set is seen as an unspecified vector. + SetType(make_intrusive(base_type(TYPE_VOID))); + return; + } - if ( auto t = maximal_type(op->AsListExpr()) ) - SetType(make_intrusive(std::move(t))); - else - { - SetError(); - return; - } - } + if ( auto t = maximal_type(op->AsListExpr()) ) + SetType(make_intrusive(std::move(t))); + else { + SetError(); + return; + } + } - if ( ! check_and_promote_exprs_to_type(op->AsListExpr(), type->AsVectorType()->Yield()) ) - ExprError("inconsistent types in vector constructor"); - } + if ( ! check_and_promote_exprs_to_type(op->AsListExpr(), type->AsVectorType()->Yield()) ) + ExprError("inconsistent types in vector constructor"); +} -ValPtr VectorConstructorExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; +ValPtr VectorConstructorExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; - auto vec = make_intrusive(GetType()); - const ExprPList& exprs = op->AsListExpr()->Exprs(); + auto vec = make_intrusive(GetType()); + const ExprPList& exprs = op->AsListExpr()->Exprs(); - loop_over_list(exprs, i) - { - Expr* e = exprs[i]; + loop_over_list(exprs, i) { + Expr* e = exprs[i]; - if ( ! vec->Assign(i, e->Eval(f)) ) - { - RuntimeError(util::fmt("type mismatch at index %d", i)); - return nullptr; - } - } + if ( ! vec->Assign(i, e->Eval(f)) ) { + RuntimeError(util::fmt("type mismatch at index %d", i)); + return nullptr; + } + } - return vec; - } + return vec; +} -void VectorConstructorExpr::ExprDescribe(ODesc* d) const - { - d->Add("vector("); - op->Describe(d); - d->Add(")"); - } +void VectorConstructorExpr::ExprDescribe(ODesc* d) const { + d->Add("vector("); + op->Describe(d); + d->Add(")"); +} FieldAssignExpr::FieldAssignExpr(const char* arg_field_name, ExprPtr value) - : UnaryExpr(EXPR_FIELD_ASSIGN, std::move(value)), field_name(arg_field_name) - { - SetType(op->GetType()); - } + : UnaryExpr(EXPR_FIELD_ASSIGN, std::move(value)), field_name(arg_field_name) { + SetType(op->GetType()); +} -bool FieldAssignExpr::PromoteTo(TypePtr t) - { - op = check_and_promote_expr(op, t); - return op != nullptr; - } +bool FieldAssignExpr::PromoteTo(TypePtr t) { + op = check_and_promote_expr(op, t); + return op != nullptr; +} -bool FieldAssignExpr::IsRecordElement(TypeDecl* td) const - { - if ( td ) - { - td->type = op->GetType(); - td->id = util::copy_string(field_name.c_str()); - } +bool FieldAssignExpr::IsRecordElement(TypeDecl* td) const { + if ( td ) { + td->type = op->GetType(); + td->id = util::copy_string(field_name.c_str()); + } - return true; - } + return true; +} -void FieldAssignExpr::ExprDescribe(ODesc* d) const - { - d->Add("$"); - d->Add(FieldName()); - d->Add("="); +void FieldAssignExpr::ExprDescribe(ODesc* d) const { + d->Add("$"); + d->Add(FieldName()); + d->Add("="); - if ( op ) - op->Describe(d); - else - d->Add(""); - } + if ( op ) + op->Describe(d); + else + d->Add(""); +} -ArithCoerceExpr::ArithCoerceExpr(ExprPtr arg_op, TypeTag t) - : UnaryExpr(EXPR_ARITH_COERCE, std::move(arg_op)) - { - if ( IsError() ) - return; +ArithCoerceExpr::ArithCoerceExpr(ExprPtr arg_op, TypeTag t) : UnaryExpr(EXPR_ARITH_COERCE, std::move(arg_op)) { + if ( IsError() ) + return; - TypeTag bt = op->GetType()->Tag(); - TypeTag vbt = bt; + TypeTag bt = op->GetType()->Tag(); + TypeTag vbt = bt; - if ( IsVector(bt) ) - { - SetType(make_intrusive(base_type(t))); - vbt = op->GetType()->AsVectorType()->Yield()->Tag(); - } - else - SetType(base_type(t)); + if ( IsVector(bt) ) { + SetType(make_intrusive(base_type(t))); + vbt = op->GetType()->AsVectorType()->Yield()->Tag(); + } + else + SetType(base_type(t)); - if ( (bt == TYPE_ENUM) != (t == TYPE_ENUM) ) - ExprError("can't convert to/from enumerated type"); + if ( (bt == TYPE_ENUM) != (t == TYPE_ENUM) ) + ExprError("can't convert to/from enumerated type"); - else if ( ! IsArithmetic(t) && ! IsBool(t) && t != TYPE_TIME && t != TYPE_INTERVAL ) - ExprError("bad coercion"); + else if ( ! IsArithmetic(t) && ! IsBool(t) && t != TYPE_TIME && t != TYPE_INTERVAL ) + ExprError("bad coercion"); - else if ( ! IsArithmetic(bt) && ! IsBool(bt) && ! IsArithmetic(vbt) && ! IsBool(vbt) ) - ExprError("bad coercion value"); - } + else if ( ! IsArithmetic(bt) && ! IsBool(bt) && ! IsArithmetic(vbt) && ! IsBool(vbt) ) + ExprError("bad coercion value"); +} -ValPtr ArithCoerceExpr::FoldSingleVal(ValPtr v, const TypePtr& t) const - { - return check_and_promote(v, t, false, location); - } +ValPtr ArithCoerceExpr::FoldSingleVal(ValPtr v, const TypePtr& t) const { + return check_and_promote(v, t, false, location); +} -ValPtr ArithCoerceExpr::Fold(Val* v) const - { - auto t = GetType(); +ValPtr ArithCoerceExpr::Fold(Val* v) const { + auto t = GetType(); - if ( ! is_vector(v) ) - { - // Our result type might be vector, in which case this - // invocation is being done per-element rather than on - // the whole vector. Correct the type if so. - if ( type->Tag() == TYPE_VECTOR ) - t = t->AsVectorType()->Yield(); + if ( ! is_vector(v) ) { + // Our result type might be vector, in which case this + // invocation is being done per-element rather than on + // the whole vector. Correct the type if so. + if ( type->Tag() == TYPE_VECTOR ) + t = t->AsVectorType()->Yield(); - return FoldSingleVal({NewRef{}, v}, t); - } + return FoldSingleVal({NewRef{}, v}, t); + } - VectorVal* vv = v->AsVectorVal(); - auto result = make_intrusive(cast_intrusive(t)); + VectorVal* vv = v->AsVectorVal(); + auto result = make_intrusive(cast_intrusive(t)); - auto yt = t->AsVectorType()->Yield(); + auto yt = t->AsVectorType()->Yield(); - for ( unsigned int i = 0; i < vv->Size(); ++i ) - { - auto elt = vv->ValAt(i); - if ( elt ) - result->Assign(i, FoldSingleVal(elt, yt)); - else - result->Assign(i, nullptr); - } + for ( unsigned int i = 0; i < vv->Size(); ++i ) { + auto elt = vv->ValAt(i); + if ( elt ) + result->Assign(i, FoldSingleVal(elt, yt)); + else + result->Assign(i, nullptr); + } - return result; - } + return result; +} // Returns true if the record type or any of its fields have an error. -static bool record_type_has_errors(const RecordType* rt) - { - if ( IsErrorType(rt->Tag()) ) - return true; +static bool record_type_has_errors(const RecordType* rt) { + if ( IsErrorType(rt->Tag()) ) + return true; - if ( rt->NumFields() > 0 ) - for ( const auto* td : *rt->Types() ) - if ( IsErrorType(td->type->Tag()) ) - return true; + if ( rt->NumFields() > 0 ) + for ( const auto* td : *rt->Types() ) + if ( IsErrorType(td->type->Tag()) ) + return true; - return false; - } + return false; +} -RecordCoerceExpr::RecordCoerceExpr(ExprPtr arg_op, RecordTypePtr r) - : UnaryExpr(EXPR_RECORD_COERCE, std::move(arg_op)) - { - if ( IsError() ) - return; +RecordCoerceExpr::RecordCoerceExpr(ExprPtr arg_op, RecordTypePtr r) : UnaryExpr(EXPR_RECORD_COERCE, std::move(arg_op)) { + if ( IsError() ) + return; - SetType(std::move(r)); + SetType(std::move(r)); - if ( GetType()->Tag() != TYPE_RECORD ) - ExprError("coercion to non-record"); + if ( GetType()->Tag() != TYPE_RECORD ) + ExprError("coercion to non-record"); - else if ( op->GetType()->Tag() != TYPE_RECORD ) - ExprError("coercion of non-record to record"); + else if ( op->GetType()->Tag() != TYPE_RECORD ) + ExprError("coercion of non-record to record"); - else - { - RecordType* t_r = type->AsRecordType(); - RecordType* sub_r = op->GetType()->AsRecordType(); + else { + RecordType* t_r = type->AsRecordType(); + RecordType* sub_r = op->GetType()->AsRecordType(); - if ( record_type_has_errors(t_r) || record_type_has_errors(sub_r) ) - { - SetError(); - return; - } + if ( record_type_has_errors(t_r) || record_type_has_errors(sub_r) ) { + SetError(); + return; + } - int map_size = t_r->NumFields(); - map.resize(map_size, -1); // -1 = field is not mapped + int map_size = t_r->NumFields(); + map.resize(map_size, -1); // -1 = field is not mapped - int i; - for ( i = 0; i < sub_r->NumFields(); ++i ) - { - int t_i = t_r->FieldOffset(sub_r->FieldName(i)); - if ( t_i < 0 ) - { - ExprError( - util::fmt("orphaned field \"%s\" in record coercion", sub_r->FieldName(i))); - break; - } + int i; + for ( i = 0; i < sub_r->NumFields(); ++i ) { + int t_i = t_r->FieldOffset(sub_r->FieldName(i)); + if ( t_i < 0 ) { + ExprError(util::fmt("orphaned field \"%s\" in record coercion", sub_r->FieldName(i))); + break; + } - const auto& sub_t_i = sub_r->GetFieldType(i); - const auto& sup_t_i = t_r->GetFieldType(t_i); + const auto& sub_t_i = sub_r->GetFieldType(i); + const auto& sup_t_i = t_r->GetFieldType(t_i); - if ( ! same_type(sup_t_i, sub_t_i) ) - { - auto is_arithmetic_promotable = [](zeek::Type* sup, zeek::Type* sub) -> bool - { - auto sup_tag = sup->Tag(); - auto sub_tag = sub->Tag(); + if ( ! same_type(sup_t_i, sub_t_i) ) { + auto is_arithmetic_promotable = [](zeek::Type* sup, zeek::Type* sub) -> bool { + auto sup_tag = sup->Tag(); + auto sub_tag = sub->Tag(); - if ( ! BothArithmetic(sup_tag, sub_tag) ) - return false; + if ( ! BothArithmetic(sup_tag, sub_tag) ) + return false; - if ( sub_tag == TYPE_DOUBLE && IsIntegral(sup_tag) ) - return false; + if ( sub_tag == TYPE_DOUBLE && IsIntegral(sup_tag) ) + return false; - if ( sub_tag == TYPE_INT && sup_tag == TYPE_COUNT ) - return false; + if ( sub_tag == TYPE_INT && sup_tag == TYPE_COUNT ) + return false; - return true; - }; + return true; + }; - auto is_record_promotable = [](zeek::Type* sup, zeek::Type* sub) -> bool - { - if ( sup->Tag() != TYPE_RECORD ) - return false; + auto is_record_promotable = [](zeek::Type* sup, zeek::Type* sub) -> bool { + if ( sup->Tag() != TYPE_RECORD ) + return false; - if ( sub->Tag() != TYPE_RECORD ) - return false; + if ( sub->Tag() != TYPE_RECORD ) + return false; - return record_promotion_compatible(sup->AsRecordType(), sub->AsRecordType()); - }; + return record_promotion_compatible(sup->AsRecordType(), sub->AsRecordType()); + }; - if ( ! is_arithmetic_promotable(sup_t_i.get(), sub_t_i.get()) && - ! is_record_promotable(sup_t_i.get(), sub_t_i.get()) ) - { - std::string error_msg = util::fmt("type clash for field \"%s\"", - sub_r->FieldName(i)); - Error(error_msg.c_str(), sub_t_i.get()); - SetError(); - break; - } - } + if ( ! is_arithmetic_promotable(sup_t_i.get(), sub_t_i.get()) && + ! is_record_promotable(sup_t_i.get(), sub_t_i.get()) ) { + std::string error_msg = util::fmt("type clash for field \"%s\"", sub_r->FieldName(i)); + Error(error_msg.c_str(), sub_t_i.get()); + SetError(); + break; + } + } - map[t_i] = i; - } + map[t_i] = i; + } - if ( IsError() ) - return; + if ( IsError() ) + return; - for ( i = 0; i < map_size; ++i ) - { - if ( map[i] == -1 ) - { - if ( ! t_r->FieldDecl(i)->GetAttr(ATTR_OPTIONAL) ) - { - std::string error_msg = util::fmt("non-optional field \"%s\" missing", - t_r->FieldName(i)); - Error(error_msg.c_str()); - SetError(); - break; - } - } - else if ( t_r->IsFieldDeprecated(i) ) - report_field_deprecation(t_r, this, i); - } - } - } + for ( i = 0; i < map_size; ++i ) { + if ( map[i] == -1 ) { + if ( ! t_r->FieldDecl(i)->GetAttr(ATTR_OPTIONAL) ) { + std::string error_msg = util::fmt("non-optional field \"%s\" missing", t_r->FieldName(i)); + Error(error_msg.c_str()); + SetError(); + break; + } + } + else if ( t_r->IsFieldDeprecated(i) ) + report_field_deprecation(t_r, this, i); + } + } +} -ValPtr RecordCoerceExpr::Fold(Val* v) const - { - if ( same_type(GetType(), Op()->GetType()) ) - return IntrusivePtr{NewRef{}, v}; +ValPtr RecordCoerceExpr::Fold(Val* v) const { + if ( same_type(GetType(), Op()->GetType()) ) + return IntrusivePtr{NewRef{}, v}; - auto rt = cast_intrusive(GetType()); - return coerce_to_record(rt, v, map); - } + auto rt = cast_intrusive(GetType()); + return coerce_to_record(rt, v, map); +} -RecordValPtr coerce_to_record(RecordTypePtr rt, Val* v, const std::vector& map) - { - int map_size = map.size(); - auto val = make_intrusive(rt); - RecordType* val_type = val->GetType()->AsRecordType(); +RecordValPtr coerce_to_record(RecordTypePtr rt, Val* v, const std::vector& map) { + int map_size = map.size(); + auto val = make_intrusive(rt); + RecordType* val_type = val->GetType()->AsRecordType(); - RecordVal* rv = v->AsRecordVal(); + RecordVal* rv = v->AsRecordVal(); - for ( int i = 0; i < map_size; ++i ) - { - if ( map[i] >= 0 ) - { - auto rhs = rv->GetField(map[i]); + for ( int i = 0; i < map_size; ++i ) { + if ( map[i] >= 0 ) { + auto rhs = rv->GetField(map[i]); - if ( ! rhs ) - { - auto rv_rt = rv->GetType()->AsRecordType(); - const auto& def = rv_rt->FieldDecl(map[i])->GetAttr(ATTR_DEFAULT); + if ( ! rhs ) { + auto rv_rt = rv->GetType()->AsRecordType(); + const auto& def = rv_rt->FieldDecl(map[i])->GetAttr(ATTR_DEFAULT); - if ( def ) - rhs = def->GetExpr()->Eval(nullptr); - } + if ( def ) + rhs = def->GetExpr()->Eval(nullptr); + } - assert(rhs || rt->FieldDecl(i)->GetAttr(ATTR_OPTIONAL)); + assert(rhs || rt->FieldDecl(i)->GetAttr(ATTR_OPTIONAL)); - if ( ! rhs ) - { - // Optional field is missing. - val->Remove(i); - continue; - } + if ( ! rhs ) { + // Optional field is missing. + val->Remove(i); + continue; + } - const auto& rhs_type = rhs->GetType(); - const auto& field_type = val_type->GetFieldType(i); + const auto& rhs_type = rhs->GetType(); + const auto& field_type = val_type->GetFieldType(i); - if ( rhs_type->Tag() == TYPE_RECORD && field_type->Tag() == TYPE_RECORD && - ! same_type(rhs_type, field_type) ) - { - if ( auto new_val = rhs->AsRecordVal()->CoerceTo( - cast_intrusive(field_type)) ) - rhs = std::move(new_val); - } - else if ( rhs_type->Tag() == TYPE_VECTOR && field_type->Tag() == TYPE_VECTOR && - rhs_type->AsVectorType()->IsUnspecifiedVector() ) - { - auto rhs_v = rhs->AsVectorVal(); - if ( ! rhs_v->Concretize(field_type->Yield()) ) - reporter->InternalError("could not concretize empty vector"); - } - else if ( BothArithmetic(rhs_type->Tag(), field_type->Tag()) && - ! same_type(rhs_type, field_type) ) - { - auto new_val = check_and_promote(rhs, field_type, false); - rhs = std::move(new_val); - } + if ( rhs_type->Tag() == TYPE_RECORD && field_type->Tag() == TYPE_RECORD && + ! same_type(rhs_type, field_type) ) { + if ( auto new_val = rhs->AsRecordVal()->CoerceTo(cast_intrusive(field_type)) ) + rhs = std::move(new_val); + } + else if ( rhs_type->Tag() == TYPE_VECTOR && field_type->Tag() == TYPE_VECTOR && + rhs_type->AsVectorType()->IsUnspecifiedVector() ) { + auto rhs_v = rhs->AsVectorVal(); + if ( ! rhs_v->Concretize(field_type->Yield()) ) + reporter->InternalError("could not concretize empty vector"); + } + else if ( BothArithmetic(rhs_type->Tag(), field_type->Tag()) && ! same_type(rhs_type, field_type) ) { + auto new_val = check_and_promote(rhs, field_type, false); + rhs = std::move(new_val); + } - val->Assign(i, std::move(rhs)); - } - else - { - if ( const auto& def = rt->FieldDecl(i)->GetAttr(ATTR_DEFAULT) ) - { - auto def_val = def->GetExpr()->Eval(nullptr); - const auto& def_type = def_val->GetType(); - const auto& field_type = rt->GetFieldType(i); + val->Assign(i, std::move(rhs)); + } + else { + if ( const auto& def = rt->FieldDecl(i)->GetAttr(ATTR_DEFAULT) ) { + auto def_val = def->GetExpr()->Eval(nullptr); + const auto& def_type = def_val->GetType(); + const auto& field_type = rt->GetFieldType(i); - if ( def_type->Tag() == TYPE_RECORD && field_type->Tag() == TYPE_RECORD && - ! same_type(def_type, field_type) ) - { - auto tmp = def_val->AsRecordVal()->CoerceTo( - cast_intrusive(field_type)); + if ( def_type->Tag() == TYPE_RECORD && field_type->Tag() == TYPE_RECORD && + ! same_type(def_type, field_type) ) { + auto tmp = def_val->AsRecordVal()->CoerceTo(cast_intrusive(field_type)); - if ( tmp ) - def_val = std::move(tmp); - } + if ( tmp ) + def_val = std::move(tmp); + } - val->Assign(i, std::move(def_val)); - } - else - val->Remove(i); - } - } + val->Assign(i, std::move(def_val)); + } + else + val->Remove(i); + } + } - return val; - } + return val; +} TableCoerceExpr::TableCoerceExpr(ExprPtr arg_op, TableTypePtr tt, bool type_check) - : UnaryExpr(EXPR_TABLE_COERCE, std::move(arg_op)) - { - if ( IsError() ) - return; + : UnaryExpr(EXPR_TABLE_COERCE, std::move(arg_op)) { + if ( IsError() ) + return; - if ( type_check ) - { - op = check_and_promote_expr(op, tt); - if ( ! op ) - { - SetError(); - return; - } + if ( type_check ) { + op = check_and_promote_expr(op, tt); + if ( ! op ) { + SetError(); + return; + } - if ( op->Tag() == EXPR_TABLE_COERCE && op->GetType() == tt ) - // Avoid double-coercion. - op = op->GetOp1(); - } + if ( op->Tag() == EXPR_TABLE_COERCE && op->GetType() == tt ) + // Avoid double-coercion. + op = op->GetOp1(); + } - SetType(std::move(tt)); + SetType(std::move(tt)); - if ( GetType()->Tag() != TYPE_TABLE ) - ExprError("coercion to non-table"); + if ( GetType()->Tag() != TYPE_TABLE ) + ExprError("coercion to non-table"); - else if ( op->GetType()->Tag() != TYPE_TABLE ) - ExprError("coercion of non-table/set to table/set"); - } + else if ( op->GetType()->Tag() != TYPE_TABLE ) + ExprError("coercion of non-table/set to table/set"); +} -ValPtr TableCoerceExpr::Fold(Val* v) const - { - TableVal* tv = v->AsTableVal(); +ValPtr TableCoerceExpr::Fold(Val* v) const { + TableVal* tv = v->AsTableVal(); - if ( tv->Size() > 0 ) - RuntimeErrorWithCallStack("coercion of non-empty table/set"); + if ( tv->Size() > 0 ) + RuntimeErrorWithCallStack("coercion of non-empty table/set"); - return make_intrusive(GetType(), tv->GetAttrs()); - } + return make_intrusive(GetType(), tv->GetAttrs()); +} -VectorCoerceExpr::VectorCoerceExpr(ExprPtr arg_op, VectorTypePtr v) - : UnaryExpr(EXPR_VECTOR_COERCE, std::move(arg_op)) - { - if ( IsError() ) - return; +VectorCoerceExpr::VectorCoerceExpr(ExprPtr arg_op, VectorTypePtr v) : UnaryExpr(EXPR_VECTOR_COERCE, std::move(arg_op)) { + if ( IsError() ) + return; - SetType(std::move(v)); + SetType(std::move(v)); - if ( GetType()->Tag() != TYPE_VECTOR ) - ExprError("coercion to non-vector"); + if ( GetType()->Tag() != TYPE_VECTOR ) + ExprError("coercion to non-vector"); - else if ( op->GetType()->Tag() != TYPE_VECTOR ) - ExprError("coercion of non-vector to vector"); - } + else if ( op->GetType()->Tag() != TYPE_VECTOR ) + ExprError("coercion of non-vector to vector"); +} -ValPtr VectorCoerceExpr::Fold(Val* v) const - { - VectorVal* vv = v->AsVectorVal(); +ValPtr VectorCoerceExpr::Fold(Val* v) const { + VectorVal* vv = v->AsVectorVal(); - if ( vv->Size() > 0 ) - RuntimeErrorWithCallStack("coercion of non-empty vector"); + if ( vv->Size() > 0 ) + RuntimeErrorWithCallStack("coercion of non-empty vector"); - return make_intrusive(GetType()); - } + return make_intrusive(GetType()); +} ScheduleTimer::ScheduleTimer(const EventHandlerPtr& arg_event, Args arg_args, double t) - : Timer(t, TIMER_SCHEDULE), event(arg_event), args(std::move(arg_args)) - { - } + : Timer(t, TIMER_SCHEDULE), event(arg_event), args(std::move(arg_args)) {} -void ScheduleTimer::Dispatch(double /* t */, bool /* is_expire */) - { - if ( event ) - event_mgr.Enqueue(event, std::move(args), util::detail::SOURCE_LOCAL, 0, nullptr, - this->Time()); - } +void ScheduleTimer::Dispatch(double /* t */, bool /* is_expire */) { + if ( event ) + event_mgr.Enqueue(event, std::move(args), util::detail::SOURCE_LOCAL, 0, nullptr, this->Time()); +} ScheduleExpr::ScheduleExpr(ExprPtr arg_when, EventExprPtr arg_event) - : Expr(EXPR_SCHEDULE), when(std::move(arg_when)), event(std::move(arg_event)) - { - if ( IsError() || when->IsError() || event->IsError() ) - return; + : Expr(EXPR_SCHEDULE), when(std::move(arg_when)), event(std::move(arg_event)) { + if ( IsError() || when->IsError() || event->IsError() ) + return; - TypeTag bt = when->GetType()->Tag(); + TypeTag bt = when->GetType()->Tag(); - if ( bt != TYPE_TIME && bt != TYPE_INTERVAL ) - ExprError("schedule expression requires a time or time interval"); - } + if ( bt != TYPE_TIME && bt != TYPE_INTERVAL ) + ExprError("schedule expression requires a time or time interval"); +} -ValPtr ScheduleExpr::Eval(Frame* f) const - { - if ( run_state::terminating ) - return nullptr; +ValPtr ScheduleExpr::Eval(Frame* f) const { + if ( run_state::terminating ) + return nullptr; - auto when_val = when->Eval(f); + auto when_val = when->Eval(f); - if ( ! when_val ) - return nullptr; + if ( ! when_val ) + return nullptr; - double dt = when_val->InternalDouble(); + double dt = when_val->InternalDouble(); - if ( when->GetType()->Tag() == TYPE_INTERVAL ) - dt += run_state::network_time; + if ( when->GetType()->Tag() == TYPE_INTERVAL ) + dt += run_state::network_time; - auto args = eval_list(f, event->Args()); + auto args = eval_list(f, event->Args()); - if ( args ) - { - auto handler = event->Handler(); + if ( args ) { + auto handler = event->Handler(); - if ( etm ) - etm->ScriptEventQueued(handler); + if ( etm ) + etm->ScriptEventQueued(handler); - timer_mgr->Add(new ScheduleTimer(handler, std::move(*args), dt)); - } + timer_mgr->Add(new ScheduleTimer(handler, std::move(*args), dt)); + } - return nullptr; - } + return nullptr; +} -TraversalCode ScheduleExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); +TraversalCode ScheduleExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); - tc = when->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); + tc = when->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); - tc = event->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); + tc = event->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} -void ScheduleExpr::ExprDescribe(ODesc* d) const - { - if ( d->IsReadable() ) - d->AddSP("schedule"); +void ScheduleExpr::ExprDescribe(ODesc* d) const { + if ( d->IsReadable() ) + d->AddSP("schedule"); - when->Describe(d); - d->SP(); + when->Describe(d); + d->SP(); - if ( d->IsReadable() ) - { - d->Add("{"); - d->PushIndent(); - event->Describe(d); - d->PopIndent(); - d->Add("}"); - } - else - event->Describe(d); - } + if ( d->IsReadable() ) { + d->Add("{"); + d->PushIndent(); + event->Describe(d); + d->PopIndent(); + d->Add("}"); + } + else + event->Describe(d); +} -InExpr::InExpr(ExprPtr arg_op1, ExprPtr arg_op2) - : BinaryExpr(EXPR_IN, std::move(arg_op1), std::move(arg_op2)) - { - if ( IsError() ) - return; +InExpr::InExpr(ExprPtr arg_op1, ExprPtr arg_op2) : BinaryExpr(EXPR_IN, std::move(arg_op1), std::move(arg_op2)) { + if ( IsError() ) + return; - if ( op1->GetType()->Tag() == TYPE_PATTERN ) - { - if ( op2->GetType()->Tag() == TYPE_STRING ) - { - SetType(base_type(TYPE_BOOL)); - return; - } - else if ( op2->GetType()->Tag() == TYPE_TABLE ) - { - // fall through to type-checking at end of function - } - else - { - op2->GetType()->Error("pattern requires string or set/table index", op1.get()); - SetError(); - return; - } - } + if ( op1->GetType()->Tag() == TYPE_PATTERN ) { + if ( op2->GetType()->Tag() == TYPE_STRING ) { + SetType(base_type(TYPE_BOOL)); + return; + } + else if ( op2->GetType()->Tag() == TYPE_TABLE ) { + // fall through to type-checking at end of function + } + else { + op2->GetType()->Error("pattern requires string or set/table index", op1.get()); + SetError(); + return; + } + } - if ( op1->GetType()->Tag() == TYPE_STRING && op2->GetType()->Tag() == TYPE_STRING ) - { - SetType(base_type(TYPE_BOOL)); - return; - } + if ( op1->GetType()->Tag() == TYPE_STRING && op2->GetType()->Tag() == TYPE_STRING ) { + SetType(base_type(TYPE_BOOL)); + return; + } - // Check for: in - // in set[subnet] - // in table[subnet] of ... - if ( op1->GetType()->Tag() == TYPE_ADDR ) - { - if ( op2->GetType()->Tag() == TYPE_SUBNET ) - { - SetType(base_type(TYPE_BOOL)); - return; - } + // Check for: in + // in set[subnet] + // in table[subnet] of ... + if ( op1->GetType()->Tag() == TYPE_ADDR ) { + if ( op2->GetType()->Tag() == TYPE_SUBNET ) { + SetType(base_type(TYPE_BOOL)); + return; + } - if ( op2->GetType()->Tag() == TYPE_TABLE && op2->GetType()->AsTableType()->IsSubNetIndex() ) - { - SetType(base_type(TYPE_BOOL)); - return; - } - } + if ( op2->GetType()->Tag() == TYPE_TABLE && op2->GetType()->AsTableType()->IsSubNetIndex() ) { + SetType(base_type(TYPE_BOOL)); + return; + } + } - if ( op1->Tag() != EXPR_LIST ) - op1 = make_intrusive(std::move(op1)); + if ( op1->Tag() != EXPR_LIST ) + op1 = make_intrusive(std::move(op1)); - ListExpr* lop1 = op1->AsListExpr(); + ListExpr* lop1 = op1->AsListExpr(); - if ( ! op2->GetType()->MatchesIndex(lop1) ) - SetError("not an index type"); - else - SetType(base_type(TYPE_BOOL)); - } + if ( ! op2->GetType()->MatchesIndex(lop1) ) + SetError("not an index type"); + else + SetType(base_type(TYPE_BOOL)); +} -ValPtr InExpr::Fold(Val* v1, Val* v2) const - { - if ( v2->GetType()->Tag() == TYPE_STRING ) - { - const String* s2 = v2->AsString(); +ValPtr InExpr::Fold(Val* v1, Val* v2) const { + if ( v2->GetType()->Tag() == TYPE_STRING ) { + const String* s2 = v2->AsString(); - if ( v1->GetType()->Tag() == TYPE_PATTERN ) - { - auto re = v1->As(); - return val_mgr->Bool(re->MatchAnywhere(s2) != 0); - } + if ( v1->GetType()->Tag() == TYPE_PATTERN ) { + auto re = v1->As(); + return val_mgr->Bool(re->MatchAnywhere(s2) != 0); + } - const String* s1 = v1->AsString(); + const String* s1 = v1->AsString(); - // Could do better here e.g. Boyer-Moore if done repeatedly. - auto s = reinterpret_cast(s1->CheckString()); - auto res = util::strstr_n(s2->Len(), s2->Bytes(), s1->Len(), s) != -1; - return val_mgr->Bool(res); - } + // Could do better here e.g. Boyer-Moore if done repeatedly. + auto s = reinterpret_cast(s1->CheckString()); + auto res = util::strstr_n(s2->Len(), s2->Bytes(), s1->Len(), s) != -1; + return val_mgr->Bool(res); + } - if ( v1->GetType()->Tag() == TYPE_ADDR && v2->GetType()->Tag() == TYPE_SUBNET ) - return val_mgr->Bool(v2->AsSubNetVal()->Contains(v1->AsAddr())); + if ( v1->GetType()->Tag() == TYPE_ADDR && v2->GetType()->Tag() == TYPE_SUBNET ) + return val_mgr->Bool(v2->AsSubNetVal()->Contains(v1->AsAddr())); - bool res; + bool res; - if ( is_vector(v2) ) - { - auto vv2 = v2->AsVectorVal(); - auto ind = v1->AsListVal()->Idx(0)->CoerceToUnsigned(); - res = ind < vv2->Size() && vv2->ValAt(ind); - } - else - res = (bool)v2->AsTableVal()->Find({NewRef{}, v1}); + if ( is_vector(v2) ) { + auto vv2 = v2->AsVectorVal(); + auto ind = v1->AsListVal()->Idx(0)->CoerceToUnsigned(); + res = ind < vv2->Size() && vv2->ValAt(ind); + } + else + res = (bool)v2->AsTableVal()->Find({NewRef{}, v1}); - return val_mgr->Bool(res); - } + return val_mgr->Bool(res); +} CallExpr::CallExpr(ExprPtr arg_func, ListExprPtr arg_args, bool in_hook, bool _in_when) - : Expr(EXPR_CALL), func(std::move(arg_func)), args(std::move(arg_args)), in_when(_in_when) - { - if ( func->IsError() || args->IsError() ) - { - SetError(); - return; - } - - const auto& func_type = func->GetType(); - - if ( ! IsFunc(func_type->Tag()) ) - { - func->Error("not a function"); - SetError(); - return; - } - - if ( func_type->AsFuncType()->Flavor() == FUNC_FLAVOR_HOOK && ! in_hook ) - { - func->Error("hook cannot be called directly, use hook operator"); - SetError(); - return; - } - - if ( record_type_has_errors(func_type->AsFuncType()->Params()->AsRecordType()) ) - SetError(); - else if ( ! func_type->MatchesIndex(args.get()) ) - SetError("argument type mismatch in function call"); - else - { - const auto& yield = func_type->Yield(); - - if ( ! yield ) - { - switch ( func_type->AsFuncType()->Flavor() ) - { - - case FUNC_FLAVOR_FUNCTION: - Error("function has no yield type"); - SetError(); - break; - - case FUNC_FLAVOR_EVENT: - Error("event called in expression, use event statement instead"); - SetError(); - break; - - case FUNC_FLAVOR_HOOK: - Error("hook has no yield type"); - SetError(); - break; - - default: - Error("invalid function flavor"); - SetError(); - break; - } - } - else - SetType(yield); - - // Check for call to built-ins that can be statically analyzed. - ValPtr func_val; - - if ( func->Tag() == EXPR_NAME && - // This is cheating, but without it processing gets - // quite confused regarding "value used but not set" - // run-time errors when we apply this analysis during - // parsing. Really we should instead do it after we've - // parsed the entire set of scripts. - util::streq(((NameExpr*)func.get())->Id()->Name(), "fmt") && - // The following is needed because fmt might not yet - // be bound as a name. - did_builtin_init && (func_val = func->Eval(nullptr)) ) - { - zeek::Func* f = func_val->AsFunc(); - if ( f->GetKind() == Func::BUILTIN_FUNC && - ! check_built_in_call((BuiltinFunc*)f, this) ) - SetError(); - } - } - } - -bool CallExpr::IsPure() const - { - if ( IsError() ) - return true; - - if ( ! func->IsPure() ) - return false; - - auto func_val = func->Eval(nullptr); - - if ( ! func_val ) - return false; - - zeek::Func* f = func_val->AsFunc(); - - // Only recurse for built-in functions, as recursing on script - // functions can lead to infinite recursion if the function being - // called here happens to be recursive (either directly - // or indirectly). - bool pure = false; - - if ( f->GetKind() == Func::BUILTIN_FUNC ) - pure = f->IsPure() && args->IsPure(); - - return pure; - } - -ValPtr CallExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; - - // If we are inside a trigger condition, we may have already been - // called, delayed, and then produced a result which is now cached. - // Check for that. - if ( f ) - { - if ( trigger::Trigger* trigger = f->GetTrigger() ) - { - if ( Val* v = trigger->Lookup((void*)this) ) - { - DBG_LOG(DBG_NOTIFIERS, "%s: provides cached function result", trigger->Name()); - return {NewRef{}, v}; - } - } - } - - ValPtr ret; - auto func_val = func->Eval(f); - auto v = eval_list(f, args.get()); - - if ( func_val && v ) - { - const zeek::Func* funcv = func_val->AsFunc(); - auto current_assoc = f ? f->GetTriggerAssoc() : nullptr; - - if ( f ) - f->SetCall(this); - - auto& args = *v; - ret = funcv->Invoke(&args, f); - - if ( f ) - f->SetTriggerAssoc(current_assoc); - } - - return ret; - } - -TraversalCode CallExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - tc = func->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = args->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -void CallExpr::ExprDescribe(ODesc* d) const - { - func->Describe(d); - if ( d->IsReadable() ) - { - d->Add("("); - args->Describe(d); - d->Add(")"); - } - else - args->Describe(d); - } - -LambdaExpr::LambdaExpr(FunctionIngredientsPtr arg_ing, IDPList arg_outer_ids, std::string name, - StmtPtr when_parent) - : Expr(EXPR_LAMBDA) - { - ingredients = std::move(arg_ing); - outer_ids = std::move(arg_outer_ids); - - auto ingr_t = ingredients->GetID()->GetType(); - SetType(ingr_t); - captures = ingr_t->GetCaptures(); - - if ( ! CheckCaptures(std::move(when_parent)) ) - { - SetError(); - return; - } - - // Now that we've validated that the captures match the outer_ids, - // we regenerate the latter to come in the same order as the captures. - // This avoids potentially subtle bugs when doing script optimization - // where one context uses the outer_ids and another uses the captures. - if ( captures ) - { - outer_ids.clear(); - for ( auto& c : *captures ) - outer_ids.append(c.Id().get()); - } - - // Install a primary version of the function globally. This is used - // by both broker (for transmitting closures) and script optimization - // (replacing its AST body with a compiled one). - primary_func = make_intrusive(ingredients->GetID()); - primary_func->SetOuterIDs(outer_ids); - - // When we build the body, it will get updated with initialization - // statements. Update the ingredients to reflect the new body, - // and no more need for initializers. - primary_func->AddBody(*ingredients); - primary_func->SetScope(ingredients->Scope()); - ingredients->ClearInits(); - - if ( name.empty() ) - BuildName(); - else - my_name = name; - - // Install that in the current scope. - lambda_id = install_ID(my_name.c_str(), current_module.c_str(), true, false); - - // Update lamb's name - primary_func->SetName(lambda_id->Name()); - - auto v = make_intrusive(primary_func); - lambda_id->SetVal(std::move(v)); - lambda_id->SetType(ingr_t); - lambda_id->SetConst(); - - analyze_lambda(this); - } - -LambdaExpr::LambdaExpr(LambdaExpr* orig) : Expr(EXPR_LAMBDA) - { - primary_func = orig->primary_func; - ingredients = orig->ingredients; - lambda_id = orig->lambda_id; - my_name = orig->my_name; - private_captures = orig->private_captures; - - // We need to have our own copies of the outer IDs and captures so - // we can rename them when inlined. - for ( auto i : orig->outer_ids ) - outer_ids.append(i); - - if ( orig->captures ) - { - captures = std::vector{}; - for ( auto& c : *orig->captures ) - captures->push_back(c); - } - - SetType(orig->GetType()); - } - -bool LambdaExpr::CheckCaptures(StmtPtr when_parent) - { - auto desc = when_parent ? "\"when\" statement" : "lambda"; - - if ( ! captures ) - { - if ( outer_ids.size() > 0 ) - { - reporter->Error("%s uses outer identifiers without [] captures: %s%s", desc, - outer_ids.size() > 1 ? "e.g., " : "", outer_ids[0]->Name()); - return false; - } - - return true; - } - - std::set outer_is_matched; - std::set capture_is_matched; - - for ( const auto& c : *captures ) - { - auto cid = c.Id().get(); - - if ( ! cid ) - // This happens for undefined/inappropriate - // identifiers listed in captures. There's - // already been an error message. - continue; - - if ( capture_is_matched.count(cid) > 0 ) - { - auto msg = util::fmt("%s listed multiple times in capture", cid->Name()); - if ( when_parent ) - when_parent->Error(msg); - else - ExprError(msg); - - return false; - } - - for ( auto id : outer_ids ) - if ( cid == id ) - { - outer_is_matched.insert(id); - capture_is_matched.insert(cid); - break; - } - } - - for ( auto id : outer_ids ) - if ( outer_is_matched.count(id) == 0 ) - { - auto msg = util::fmt("%s is used inside %s but not captured", id->Name(), desc); - if ( when_parent ) - when_parent->Error(msg); - else - ExprError(msg); - - return false; - } - - for ( const auto& c : *captures ) - { - auto cid = c.Id().get(); - if ( cid && capture_is_matched.count(cid) == 0 ) - { - auto msg = util::fmt("%s is captured but not used inside %s", cid->Name(), desc); - if ( when_parent ) - when_parent->Error(msg); - else - ExprError(msg); - - return false; - } - } - - return true; - } - -void LambdaExpr::BuildName() - { - // Get the body's "string" representation. - ODesc d; - primary_func->Describe(&d); - - if ( captures ) - for ( auto& c : *captures ) - { - if ( c.IsDeepCopy() ) - d.AddSP("copy"); - - if ( c.Id() ) - // c.Id() will be nil for some errors - c.Id()->Describe(&d); - } - - for ( ;; ) - { - hash128_t h; - KeyedHash::Hash128(d.Bytes(), d.Len(), &h); - - my_name = "lambda_<" + std::to_string(h[0]) + ">"; - auto fullname = make_full_var_name(current_module.data(), my_name.data()); - const auto& id = current_scope()->Find(fullname); - - if ( id ) - // Just try again to make a unique lambda name. - // If two peer processes need to agree on the same - // lambda name, this assumes they're loading the same - // scripts and thus have the same hash collisions. - d.Add(" "); - else - break; - } - } - -ScopePtr LambdaExpr::GetScope() const - { - return ingredients->Scope(); - } - -void LambdaExpr::ReplaceBody(StmtPtr new_body) - { - ingredients->ReplaceBody(std::move(new_body)); - } - -ValPtr LambdaExpr::Eval(Frame* f) const - { - auto lamb = make_intrusive(ingredients->GetID()); - - // Use the primary function as the source of the frame size - // and function body, rather than the ingredients, since script - // optimization might have changed the former but not the latter. - lamb->SetFrameSize(primary_func->FrameSize()); - StmtPtr body = primary_func->GetBodies()[0].stmts; - - if ( run_state::is_parsing ) - // We're evaluating this lambda at parse time, which happens - // for initializations. If we're doing script optimization - // then the current version of the body might be left in an - // inconsistent state (e.g., if it's replaced with ZAM code) - // causing problems if we execute this lambda subsequently. - // To avoid that problem, we duplicate the AST so it's - // distinct. - body = body->Duplicate(); - - lamb->AddBody(*ingredients, body); - lamb->CreateCaptures(f); - - // Set name to corresponding master func. - // Allows for lookups by the receiver. - lamb->SetName(my_name.c_str()); - - return make_intrusive(std::move(lamb)); - } - -void LambdaExpr::ExprDescribe(ODesc* d) const - { - d->Add(expr_name(Tag())); - - if ( captures && d->IsReadable() ) - { - d->Add("["); - - for ( auto& c : *captures ) - { - if ( &c != &(*captures)[0] ) - d->AddSP(", "); - - if ( c.IsDeepCopy() ) - d->AddSP("copy"); - - d->Add(c.Id()->Name()); - } - - d->Add("]"); - } - - ingredients->Body()->Describe(d); - } - -TraversalCode LambdaExpr::Traverse(TraversalCallback* cb) const - { - if ( IsError() ) - // Not well-formed. - return TC_CONTINUE; - - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - tc = lambda_id->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = ingredients->Body()->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } + : Expr(EXPR_CALL), func(std::move(arg_func)), args(std::move(arg_args)), in_when(_in_when) { + if ( func->IsError() || args->IsError() ) { + SetError(); + return; + } + + const auto& func_type = func->GetType(); + + if ( ! IsFunc(func_type->Tag()) ) { + func->Error("not a function"); + SetError(); + return; + } + + if ( func_type->AsFuncType()->Flavor() == FUNC_FLAVOR_HOOK && ! in_hook ) { + func->Error("hook cannot be called directly, use hook operator"); + SetError(); + return; + } + + if ( record_type_has_errors(func_type->AsFuncType()->Params()->AsRecordType()) ) + SetError(); + else if ( ! func_type->MatchesIndex(args.get()) ) + SetError("argument type mismatch in function call"); + else { + const auto& yield = func_type->Yield(); + + if ( ! yield ) { + switch ( func_type->AsFuncType()->Flavor() ) { + case FUNC_FLAVOR_FUNCTION: + Error("function has no yield type"); + SetError(); + break; + + case FUNC_FLAVOR_EVENT: + Error("event called in expression, use event statement instead"); + SetError(); + break; + + case FUNC_FLAVOR_HOOK: + Error("hook has no yield type"); + SetError(); + break; + + default: + Error("invalid function flavor"); + SetError(); + break; + } + } + else + SetType(yield); + + // Check for call to built-ins that can be statically analyzed. + ValPtr func_val; + + if ( func->Tag() == EXPR_NAME && + // This is cheating, but without it processing gets + // quite confused regarding "value used but not set" + // run-time errors when we apply this analysis during + // parsing. Really we should instead do it after we've + // parsed the entire set of scripts. + util::streq(((NameExpr*)func.get())->Id()->Name(), "fmt") && + // The following is needed because fmt might not yet + // be bound as a name. + did_builtin_init && (func_val = func->Eval(nullptr)) ) { + zeek::Func* f = func_val->AsFunc(); + if ( f->GetKind() == Func::BUILTIN_FUNC && ! check_built_in_call((BuiltinFunc*)f, this) ) + SetError(); + } + } +} + +bool CallExpr::IsPure() const { + if ( IsError() ) + return true; + + if ( ! func->IsPure() ) + return false; + + auto func_val = func->Eval(nullptr); + + if ( ! func_val ) + return false; + + zeek::Func* f = func_val->AsFunc(); + + // Only recurse for built-in functions, as recursing on script + // functions can lead to infinite recursion if the function being + // called here happens to be recursive (either directly + // or indirectly). + bool pure = false; + + if ( f->GetKind() == Func::BUILTIN_FUNC ) + pure = f->IsPure() && args->IsPure(); + + return pure; +} + +ValPtr CallExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; + + // If we are inside a trigger condition, we may have already been + // called, delayed, and then produced a result which is now cached. + // Check for that. + if ( f ) { + if ( trigger::Trigger* trigger = f->GetTrigger() ) { + if ( Val* v = trigger->Lookup((void*)this) ) { + DBG_LOG(DBG_NOTIFIERS, "%s: provides cached function result", trigger->Name()); + return {NewRef{}, v}; + } + } + } + + ValPtr ret; + auto func_val = func->Eval(f); + auto v = eval_list(f, args.get()); + + if ( func_val && v ) { + const zeek::Func* funcv = func_val->AsFunc(); + auto current_assoc = f ? f->GetTriggerAssoc() : nullptr; + + if ( f ) + f->SetCall(this); + + auto& args = *v; + ret = funcv->Invoke(&args, f); + + if ( f ) + f->SetTriggerAssoc(current_assoc); + } + + return ret; +} + +TraversalCode CallExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = func->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = args->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +void CallExpr::ExprDescribe(ODesc* d) const { + func->Describe(d); + if ( d->IsReadable() ) { + d->Add("("); + args->Describe(d); + d->Add(")"); + } + else + args->Describe(d); +} + +LambdaExpr::LambdaExpr(FunctionIngredientsPtr arg_ing, IDPList arg_outer_ids, std::string name, StmtPtr when_parent) + : Expr(EXPR_LAMBDA) { + ingredients = std::move(arg_ing); + outer_ids = std::move(arg_outer_ids); + + auto ingr_t = ingredients->GetID()->GetType(); + SetType(ingr_t); + captures = ingr_t->GetCaptures(); + + if ( ! CheckCaptures(std::move(when_parent)) ) { + SetError(); + return; + } + + // Now that we've validated that the captures match the outer_ids, + // we regenerate the latter to come in the same order as the captures. + // This avoids potentially subtle bugs when doing script optimization + // where one context uses the outer_ids and another uses the captures. + if ( captures ) { + outer_ids.clear(); + for ( auto& c : *captures ) + outer_ids.append(c.Id().get()); + } + + // Install a primary version of the function globally. This is used + // by both broker (for transmitting closures) and script optimization + // (replacing its AST body with a compiled one). + primary_func = make_intrusive(ingredients->GetID()); + primary_func->SetOuterIDs(outer_ids); + + // When we build the body, it will get updated with initialization + // statements. Update the ingredients to reflect the new body, + // and no more need for initializers. + primary_func->AddBody(*ingredients); + primary_func->SetScope(ingredients->Scope()); + ingredients->ClearInits(); + + if ( name.empty() ) + BuildName(); + else + my_name = name; + + // Install that in the current scope. + lambda_id = install_ID(my_name.c_str(), current_module.c_str(), true, false); + + // Update lamb's name + primary_func->SetName(lambda_id->Name()); + + auto v = make_intrusive(primary_func); + lambda_id->SetVal(std::move(v)); + lambda_id->SetType(ingr_t); + lambda_id->SetConst(); + + analyze_lambda(this); +} + +LambdaExpr::LambdaExpr(LambdaExpr* orig) : Expr(EXPR_LAMBDA) { + primary_func = orig->primary_func; + ingredients = orig->ingredients; + lambda_id = orig->lambda_id; + my_name = orig->my_name; + private_captures = orig->private_captures; + + // We need to have our own copies of the outer IDs and captures so + // we can rename them when inlined. + for ( auto i : orig->outer_ids ) + outer_ids.append(i); + + if ( orig->captures ) { + captures = std::vector{}; + for ( auto& c : *orig->captures ) + captures->push_back(c); + } + + SetType(orig->GetType()); +} + +bool LambdaExpr::CheckCaptures(StmtPtr when_parent) { + auto desc = when_parent ? "\"when\" statement" : "lambda"; + + if ( ! captures ) { + if ( outer_ids.size() > 0 ) { + reporter->Error("%s uses outer identifiers without [] captures: %s%s", desc, + outer_ids.size() > 1 ? "e.g., " : "", outer_ids[0]->Name()); + return false; + } + + return true; + } + + std::set outer_is_matched; + std::set capture_is_matched; + + for ( const auto& c : *captures ) { + auto cid = c.Id().get(); + + if ( ! cid ) + // This happens for undefined/inappropriate + // identifiers listed in captures. There's + // already been an error message. + continue; + + if ( capture_is_matched.count(cid) > 0 ) { + auto msg = util::fmt("%s listed multiple times in capture", cid->Name()); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); + + return false; + } + + for ( auto id : outer_ids ) + if ( cid == id ) { + outer_is_matched.insert(id); + capture_is_matched.insert(cid); + break; + } + } + + for ( auto id : outer_ids ) + if ( outer_is_matched.count(id) == 0 ) { + auto msg = util::fmt("%s is used inside %s but not captured", id->Name(), desc); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); + + return false; + } + + for ( const auto& c : *captures ) { + auto cid = c.Id().get(); + if ( cid && capture_is_matched.count(cid) == 0 ) { + auto msg = util::fmt("%s is captured but not used inside %s", cid->Name(), desc); + if ( when_parent ) + when_parent->Error(msg); + else + ExprError(msg); + + return false; + } + } + + return true; +} + +void LambdaExpr::BuildName() { + // Get the body's "string" representation. + ODesc d; + primary_func->Describe(&d); + + if ( captures ) + for ( auto& c : *captures ) { + if ( c.IsDeepCopy() ) + d.AddSP("copy"); + + if ( c.Id() ) + // c.Id() will be nil for some errors + c.Id()->Describe(&d); + } + + for ( ;; ) { + hash128_t h; + KeyedHash::Hash128(d.Bytes(), d.Len(), &h); + + my_name = "lambda_<" + std::to_string(h[0]) + ">"; + auto fullname = make_full_var_name(current_module.data(), my_name.data()); + const auto& id = current_scope()->Find(fullname); + + if ( id ) + // Just try again to make a unique lambda name. + // If two peer processes need to agree on the same + // lambda name, this assumes they're loading the same + // scripts and thus have the same hash collisions. + d.Add(" "); + else + break; + } +} + +ScopePtr LambdaExpr::GetScope() const { return ingredients->Scope(); } + +void LambdaExpr::ReplaceBody(StmtPtr new_body) { ingredients->ReplaceBody(std::move(new_body)); } + +ValPtr LambdaExpr::Eval(Frame* f) const { + auto lamb = make_intrusive(ingredients->GetID()); + + // Use the primary function as the source of the frame size + // and function body, rather than the ingredients, since script + // optimization might have changed the former but not the latter. + lamb->SetFrameSize(primary_func->FrameSize()); + StmtPtr body = primary_func->GetBodies()[0].stmts; + + if ( run_state::is_parsing ) + // We're evaluating this lambda at parse time, which happens + // for initializations. If we're doing script optimization + // then the current version of the body might be left in an + // inconsistent state (e.g., if it's replaced with ZAM code) + // causing problems if we execute this lambda subsequently. + // To avoid that problem, we duplicate the AST so it's + // distinct. + body = body->Duplicate(); + + lamb->AddBody(*ingredients, body); + lamb->CreateCaptures(f); + + // Set name to corresponding master func. + // Allows for lookups by the receiver. + lamb->SetName(my_name.c_str()); + + return make_intrusive(std::move(lamb)); +} + +void LambdaExpr::ExprDescribe(ODesc* d) const { + d->Add(expr_name(Tag())); + + if ( captures && d->IsReadable() ) { + d->Add("["); + + for ( auto& c : *captures ) { + if ( &c != &(*captures)[0] ) + d->AddSP(", "); + + if ( c.IsDeepCopy() ) + d->AddSP("copy"); + + d->Add(c.Id()->Name()); + } + + d->Add("]"); + } + + ingredients->Body()->Describe(d); +} + +TraversalCode LambdaExpr::Traverse(TraversalCallback* cb) const { + if ( IsError() ) + // Not well-formed. + return TC_CONTINUE; + + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + tc = lambda_id->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = ingredients->Body()->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} EventExpr::EventExpr(const char* arg_name, ListExprPtr arg_args) - : Expr(EXPR_EVENT), name(arg_name), args(std::move(arg_args)) - { - EventHandler* h = event_registry->Lookup(name); - - if ( ! h ) - { - h = new EventHandler(name.c_str()); - event_registry->Register(h, true); - } - - h->SetUsed(); - - handler = h; - - if ( args->IsError() ) - { - SetError(); - return; - } - - const auto& func_type = h->GetType(); - - if ( ! func_type ) - { - Error("not an event"); - SetError(); - return; - } - - if ( record_type_has_errors(func_type->AsFuncType()->Params()->AsRecordType()) ) - SetError(); - else if ( ! func_type->MatchesIndex(args.get()) ) - SetError("argument type mismatch in event invocation"); - else - { - if ( func_type->Yield() ) - { - Error("function invoked as an event"); - SetError(); - } - } - } - -ValPtr EventExpr::Eval(Frame* f) const - { - if ( IsError() ) - return nullptr; - - auto v = eval_list(f, args.get()); - - if ( handler ) - { - if ( etm ) - etm->ScriptEventQueued(handler); - - event_mgr.Enqueue(handler, std::move(*v)); - } - - return nullptr; - } - -TraversalCode EventExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - auto& f = handler->GetFunc(); - if ( f ) - { - // We don't traverse the function, because that can lead - // to infinite traversals. We do, however, see if we can - // locate the corresponding identifier, and traverse that. - - auto& id = lookup_ID(f->Name(), GLOBAL_MODULE_NAME, false, false, false); - - if ( id ) - { - tc = id->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - } - } - - tc = args->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -void EventExpr::ExprDescribe(ODesc* d) const - { - d->Add(name.c_str()); - if ( d->IsReadable() ) - { - d->Add("("); - args->Describe(d); - d->Add(")"); - } - else - args->Describe(d); - } - -ListExpr::ListExpr() : Expr(EXPR_LIST) - { - SetType(make_intrusive()); - } - -ListExpr::ListExpr(ExprPtr e) : Expr(EXPR_LIST) - { - SetType(make_intrusive()); - Append(std::move(e)); - } - -ListExpr::~ListExpr() - { - for ( const auto& expr : exprs ) - Unref(expr); - } - -void ListExpr::Append(ExprPtr e) - { - exprs.push_back(e.release()); - ((TypeList*)type.get())->Append(exprs.back()->GetType()); - } - -bool ListExpr::IsPure() const - { - for ( const auto& expr : exprs ) - if ( ! expr->IsPure() ) - return false; - - return true; - } - -ValPtr ListExpr::Eval(Frame* f) const - { - std::vector evs; - - for ( const auto& expr : exprs ) - { - auto ev = expr->Eval(f); - - if ( ! ev ) - { - RuntimeError("uninitialized list value"); - return nullptr; - } - - evs.push_back(std::move(ev)); - } - - return make_intrusive(cast_intrusive(type), std::move(evs)); - } - -TypePtr ListExpr::InitType() const - { - if ( exprs.empty() ) - { - Error("empty list in untyped initialization"); - return nullptr; - } - - if ( exprs[0]->IsRecordElement(nullptr) ) - { - type_decl_list* types = new type_decl_list(exprs.length()); - for ( const auto& expr : exprs ) - { - TypeDecl* td = new TypeDecl(nullptr, nullptr); - if ( ! expr->IsRecordElement(td) ) - { - expr->Error("record element expected"); - delete td; - delete types; - return nullptr; - } - - types->push_back(td); - } - - return make_intrusive(types); - } - - else - { - auto tl = make_intrusive(); - - for ( const auto& e : exprs ) - { - const auto& ti = e->GetType(); - - // Collapse any embedded sets or lists. - if ( ti->IsSet() || ti->Tag() == TYPE_LIST ) - { - TypeList* til = ti->IsSet() ? ti->AsSetType()->GetIndices().get() - : ti->AsTypeList(); - - if ( ! til->IsPure() || ! til->AllMatch(til->GetPureType(), true) ) - tl->Append({NewRef{}, til}); - else - tl->Append(til->GetPureType()); - } - else - tl->Append(ti); - } - - return tl; - } - } - -void ListExpr::ExprDescribe(ODesc* d) const - { - d->AddCount(exprs.length()); - - loop_over_list(exprs, i) - { - if ( d->IsReadable() && i > 0 ) - d->Add(", "); - - exprs[i]->Describe(d); - } - } - -ExprPtr ListExpr::MakeLvalue() - { - for ( const auto& expr : exprs ) - if ( expr->Tag() != EXPR_NAME ) - ExprError("can only assign to list of identifiers"); - - return make_intrusive(ThisPtr()); - } - -void ListExpr::Assign(Frame* f, ValPtr v) - { - ListVal* lv = v->AsListVal(); - - if ( exprs.length() != lv->Length() ) - RuntimeError("mismatch in list lengths"); - - loop_over_list(exprs, i) exprs[i]->Assign(f, lv->Idx(i)); - } - -TraversalCode ListExpr::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreExpr(this); - HANDLE_TC_EXPR_PRE(tc); - - for ( const auto& expr : exprs ) - { - tc = expr->Traverse(cb); - HANDLE_TC_EXPR_PRE(tc); - } - - tc = cb->PostExpr(this); - HANDLE_TC_EXPR_POST(tc); - } - -RecordAssignExpr::RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init) - { - const ExprPList& inits = init_list->AsListExpr()->Exprs(); - - RecordType* lhs = record->GetType()->AsRecordType(); - - // The inits have two forms: - // 1) other records -- use all matching field names+types - // 2) a string indicating the field name, then (as the next element) - // the value to use for that field. - - for ( const auto& init : inits ) - { - if ( init->GetType()->Tag() == TYPE_RECORD ) - { - RecordType* t = init->GetType()->AsRecordType(); - - for ( int j = 0; j < t->NumFields(); ++j ) - { - const char* field_name = t->FieldName(j); - int field = lhs->FieldOffset(field_name); - - if ( field >= 0 && same_type(lhs->GetFieldType(field), t->GetFieldType(j)) ) - { - auto fe_lhs = make_intrusive(record, field_name); - auto fe_rhs = make_intrusive(IntrusivePtr{NewRef{}, init}, - field_name); - Append(get_assign_expr(std::move(fe_lhs), std::move(fe_rhs), is_init)); - } - } - } - - else if ( init->Tag() == EXPR_FIELD_ASSIGN ) - { - FieldAssignExpr* rf = (FieldAssignExpr*)init; - rf->Ref(); - - const char* field_name = ""; // rf->FieldName(); - if ( lhs->HasField(field_name) ) - { - auto fe_lhs = make_intrusive(record, field_name); - ExprPtr fe_rhs = {NewRef{}, rf->Op()}; - Append(get_assign_expr(std::move(fe_lhs), std::move(fe_rhs), is_init)); - } - else - { - std::string s = "No such field '"; - s += field_name; - s += "'"; - init_list->SetError(s.c_str()); - } - } - - else - { - init_list->SetError("bad record initializer"); - return; - } - } - } - -CastExpr::CastExpr(ExprPtr arg_op, TypePtr t) : UnaryExpr(EXPR_CAST, std::move(arg_op)) - { - auto stype = Op()->GetType(); - - SetType(std::move(t)); - - if ( ! can_cast_value_to_type(stype.get(), GetType().get()) ) - ExprError("cast not supported"); - } - -ValPtr CastExpr::Fold(Val* v) const - { - std::string error; - auto res = cast_value({NewRef{}, v}, GetType(), error); - - if ( ! res ) - RuntimeError(error.c_str()); - - return res; - } - -ValPtr cast_value(ValPtr v, const TypePtr& t, std::string& error) - { - auto nv = cast_value_to_type(v.get(), t.get()); - - if ( nv ) - return nv; - - ODesc d; - - d.Add("invalid cast of value with type '"); - v->GetType()->Describe(&d); - d.Add("' to type '"); - t->Describe(&d); - d.Add("'"); - - if ( same_type(v->GetType(), Broker::detail::DataVal::ScriptDataType()) && - ! v->AsRecordVal()->HasField(0) ) - d.Add(" (nil $data field)"); - - error = d.Description(); - return nullptr; - } - -void CastExpr::ExprDescribe(ODesc* d) const - { - Op()->Describe(d); - d->Add(" as "); - GetType()->Describe(d); - } - -IsExpr::IsExpr(ExprPtr arg_op, TypePtr arg_t) - : UnaryExpr(EXPR_IS, std::move(arg_op)), t(std::move(arg_t)) - { - SetType(base_type(TYPE_BOOL)); - } - -ValPtr IsExpr::Fold(Val* v) const - { - if ( IsError() ) - return nullptr; - - return val_mgr->Bool(can_cast_value_to_type(v, t.get())); - } - -void IsExpr::ExprDescribe(ODesc* d) const - { - Op()->Describe(d); - d->Add(" is "); - t->Describe(d); - } - -ExprPtr get_assign_expr(ExprPtr op1, ExprPtr op2, bool is_init) - { - if ( op1->GetType()->Tag() == TYPE_RECORD && op2->GetType()->Tag() == TYPE_LIST ) - return make_intrusive(std::move(op1), std::move(op2), is_init); - - else if ( op1->Tag() == EXPR_INDEX && op1->AsIndexExpr()->IsSlice() ) - return make_intrusive(std::move(op1), std::move(op2), is_init); - - else - return make_intrusive(std::move(op1), std::move(op2), is_init); - } - -ExprPtr check_and_promote_expr(ExprPtr e, TypePtr t) - { - const auto& et = e->GetType(); - TypeTag e_tag = et->Tag(); - TypeTag t_tag = t->Tag(); - - if ( t_tag == TYPE_ANY ) - { - if ( e_tag != TYPE_ANY ) - return make_intrusive(e); - - return e; - } - - if ( e_tag == TYPE_ANY ) - return make_intrusive(e, t); - - if ( EitherArithmetic(t_tag, e_tag) ) - { - if ( e_tag == t_tag ) - return e; - - if ( ! BothArithmetic(t_tag, e_tag) ) - { - t->Error("arithmetic mixed with non-arithmetic", e.get()); - return nullptr; - } - - TypeTag mt = max_type(t_tag, e_tag); - if ( mt != t_tag ) - { - t->Error("over-promotion of arithmetic value", e.get()); - return nullptr; - } - - return make_intrusive(e, t_tag); - } - - if ( t->Tag() == TYPE_RECORD && et->Tag() == TYPE_RECORD ) - { - RecordType* t_r = t->AsRecordType(); - RecordType* et_r = et->AsRecordType(); - - if ( same_type(t, et) ) - return e; - - if ( record_promotion_compatible(t_r, et_r) ) - return make_intrusive(e, IntrusivePtr{NewRef{}, t_r}); - - t->Error("incompatible record types", e.get()); - return nullptr; - } - - if ( ! same_type(t, et) ) - { - if ( t->Tag() == TYPE_TABLE && et->Tag() == TYPE_TABLE && - et->AsTableType()->IsUnspecifiedTable() ) - { - if ( e->Tag() == EXPR_TABLE_CONSTRUCTOR ) - { - auto& attrs = cast_intrusive(e)->GetAttrs(); - zeek::detail::AttrPtr def = Attr::nil; - - // Check for &default or &default_insert expressions - // and use it for type checking against t. - if ( attrs ) - { - def = attrs->Find(ATTR_DEFAULT); - if ( ! def ) - def = attrs->Find(ATTR_DEFAULT_INSERT); - } - - if ( def ) - { - std::string err_msg; - if ( ! check_default_attr(def.get(), t, false, false, err_msg) ) - { - if ( ! err_msg.empty() ) - t->Error(err_msg.c_str(), e.get()); - return nullptr; - } - } - } - - return make_intrusive(e, IntrusivePtr{NewRef{}, t->AsTableType()}, - false); - } - - if ( t->Tag() == TYPE_VECTOR && et->Tag() == TYPE_VECTOR && - et->AsVectorType()->IsUnspecifiedVector() ) - return make_intrusive(e, IntrusivePtr{NewRef{}, t->AsVectorType()}); - - if ( t->Tag() != TYPE_ERROR && et->Tag() != TYPE_ERROR ) - t->Error("type clash", e.get()); - - return nullptr; - } - - return e; - } - -bool check_and_promote_exprs(ListExpr* const elements, const TypeListPtr& types) - { - ExprPList& el = elements->Exprs(); - const auto& tl = types->GetTypes(); - - if ( tl.size() == 1 && tl[0]->Tag() == TYPE_ANY ) - return true; - - if ( el.length() != static_cast(tl.size()) ) - { - types->Error("indexing mismatch", elements); - return false; - } - - loop_over_list(el, i) - { - ExprPtr e = {NewRef{}, el[i]}; - auto promoted_e = check_and_promote_expr(e, tl[i]); - - if ( ! promoted_e ) - { - e->Error("type mismatch", tl[i].get()); - return false; - } - - if ( promoted_e != e ) - Unref(el.replace(i, promoted_e.release())); - } - - return true; - } - -bool check_and_promote_args(ListExpr* const args, const RecordType* types) - { - ExprPList& el = args->Exprs(); - int ntypes = types->NumFields(); - - // give variadic BIFs automatic pass - if ( ntypes == 1 && types->FieldDecl(0)->type->Tag() == TYPE_ANY ) - return true; - - if ( el.length() < ntypes ) - { - std::vector def_elements; - - // Start from rightmost parameter, work backward to fill in missing - // arguments using &default expressions. - for ( int i = ntypes - 1; i >= el.length(); --i ) - { - auto td = types->FieldDecl(i); - const auto& def_attr = td->attrs ? td->attrs->Find(ATTR_DEFAULT).get() : nullptr; - - if ( ! def_attr ) - { - types->Error("parameter mismatch", args); - return false; - } - - // Don't use the default expression directly, as - // doing so will wind up sharing its code across - // different invocations that use the default - // argument. That works okay for the interpreter, - // but if we transform the code we want that done - // separately for each instance, rather than - // one instance inheriting the transformed version - // from another. - const auto& e = def_attr->GetExpr(); - def_elements.emplace_back(e->Duplicate()); - } - - auto ne = def_elements.size(); - while ( ne ) - el.push_back(def_elements[--ne].release()); - } - - auto tl = make_intrusive(); - - for ( int i = 0; i < types->NumFields(); ++i ) - tl->Append(types->GetFieldType(i)); - - int rval = check_and_promote_exprs(args, tl); - - return rval; - } - -bool check_and_promote_exprs_to_type(ListExpr* const elements, TypePtr t) - { - ExprPList& el = elements->Exprs(); - - if ( t->Tag() == TYPE_ANY ) - return true; - - loop_over_list(el, i) - { - ExprPtr e = {NewRef{}, el[i]}; - auto promoted_e = check_and_promote_expr(e, t); - - if ( ! promoted_e ) - { - e->Error("type mismatch", t.get()); - return false; - } - - if ( promoted_e != e ) - Unref(el.replace(i, promoted_e.release())); - } - - return true; - } - -std::optional> eval_list(Frame* f, const ListExpr* l) - { - const ExprPList& e = l->Exprs(); - auto rval = std::make_optional>(); - rval->reserve(e.length()); - - for ( const auto& expr : e ) - { - auto ev = expr->Eval(f); - - if ( ! ev ) - return {}; - - rval->emplace_back(std::move(ev)); - } - - return rval; - } - -bool expr_greater(const Expr* e1, const Expr* e2) - { - return e1->Tag() > e2->Tag(); - } - - } // namespace zeek::detail + : Expr(EXPR_EVENT), name(arg_name), args(std::move(arg_args)) { + EventHandler* h = event_registry->Lookup(name); + + if ( ! h ) { + h = new EventHandler(name.c_str()); + event_registry->Register(h, true); + } + + h->SetUsed(); + + handler = h; + + if ( args->IsError() ) { + SetError(); + return; + } + + const auto& func_type = h->GetType(); + + if ( ! func_type ) { + Error("not an event"); + SetError(); + return; + } + + if ( record_type_has_errors(func_type->AsFuncType()->Params()->AsRecordType()) ) + SetError(); + else if ( ! func_type->MatchesIndex(args.get()) ) + SetError("argument type mismatch in event invocation"); + else { + if ( func_type->Yield() ) { + Error("function invoked as an event"); + SetError(); + } + } +} + +ValPtr EventExpr::Eval(Frame* f) const { + if ( IsError() ) + return nullptr; + + auto v = eval_list(f, args.get()); + + if ( handler ) { + if ( etm ) + etm->ScriptEventQueued(handler); + + event_mgr.Enqueue(handler, std::move(*v)); + } + + return nullptr; +} + +TraversalCode EventExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + auto& f = handler->GetFunc(); + if ( f ) { + // We don't traverse the function, because that can lead + // to infinite traversals. We do, however, see if we can + // locate the corresponding identifier, and traverse that. + + auto& id = lookup_ID(f->Name(), GLOBAL_MODULE_NAME, false, false, false); + + if ( id ) { + tc = id->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + } + } + + tc = args->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +void EventExpr::ExprDescribe(ODesc* d) const { + d->Add(name.c_str()); + if ( d->IsReadable() ) { + d->Add("("); + args->Describe(d); + d->Add(")"); + } + else + args->Describe(d); +} + +ListExpr::ListExpr() : Expr(EXPR_LIST) { SetType(make_intrusive()); } + +ListExpr::ListExpr(ExprPtr e) : Expr(EXPR_LIST) { + SetType(make_intrusive()); + Append(std::move(e)); +} + +ListExpr::~ListExpr() { + for ( const auto& expr : exprs ) + Unref(expr); +} + +void ListExpr::Append(ExprPtr e) { + exprs.push_back(e.release()); + ((TypeList*)type.get())->Append(exprs.back()->GetType()); +} + +bool ListExpr::IsPure() const { + for ( const auto& expr : exprs ) + if ( ! expr->IsPure() ) + return false; + + return true; +} + +ValPtr ListExpr::Eval(Frame* f) const { + std::vector evs; + + for ( const auto& expr : exprs ) { + auto ev = expr->Eval(f); + + if ( ! ev ) { + RuntimeError("uninitialized list value"); + return nullptr; + } + + evs.push_back(std::move(ev)); + } + + return make_intrusive(cast_intrusive(type), std::move(evs)); +} + +TypePtr ListExpr::InitType() const { + if ( exprs.empty() ) { + Error("empty list in untyped initialization"); + return nullptr; + } + + if ( exprs[0]->IsRecordElement(nullptr) ) { + type_decl_list* types = new type_decl_list(exprs.length()); + for ( const auto& expr : exprs ) { + TypeDecl* td = new TypeDecl(nullptr, nullptr); + if ( ! expr->IsRecordElement(td) ) { + expr->Error("record element expected"); + delete td; + delete types; + return nullptr; + } + + types->push_back(td); + } + + return make_intrusive(types); + } + + else { + auto tl = make_intrusive(); + + for ( const auto& e : exprs ) { + const auto& ti = e->GetType(); + + // Collapse any embedded sets or lists. + if ( ti->IsSet() || ti->Tag() == TYPE_LIST ) { + TypeList* til = ti->IsSet() ? ti->AsSetType()->GetIndices().get() : ti->AsTypeList(); + + if ( ! til->IsPure() || ! til->AllMatch(til->GetPureType(), true) ) + tl->Append({NewRef{}, til}); + else + tl->Append(til->GetPureType()); + } + else + tl->Append(ti); + } + + return tl; + } +} + +void ListExpr::ExprDescribe(ODesc* d) const { + d->AddCount(exprs.length()); + + loop_over_list(exprs, i) { + if ( d->IsReadable() && i > 0 ) + d->Add(", "); + + exprs[i]->Describe(d); + } +} + +ExprPtr ListExpr::MakeLvalue() { + for ( const auto& expr : exprs ) + if ( expr->Tag() != EXPR_NAME ) + ExprError("can only assign to list of identifiers"); + + return make_intrusive(ThisPtr()); +} + +void ListExpr::Assign(Frame* f, ValPtr v) { + ListVal* lv = v->AsListVal(); + + if ( exprs.length() != lv->Length() ) + RuntimeError("mismatch in list lengths"); + + loop_over_list(exprs, i) exprs[i]->Assign(f, lv->Idx(i)); +} + +TraversalCode ListExpr::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreExpr(this); + HANDLE_TC_EXPR_PRE(tc); + + for ( const auto& expr : exprs ) { + tc = expr->Traverse(cb); + HANDLE_TC_EXPR_PRE(tc); + } + + tc = cb->PostExpr(this); + HANDLE_TC_EXPR_POST(tc); +} + +RecordAssignExpr::RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init) { + const ExprPList& inits = init_list->AsListExpr()->Exprs(); + + RecordType* lhs = record->GetType()->AsRecordType(); + + // The inits have two forms: + // 1) other records -- use all matching field names+types + // 2) a string indicating the field name, then (as the next element) + // the value to use for that field. + + for ( const auto& init : inits ) { + if ( init->GetType()->Tag() == TYPE_RECORD ) { + RecordType* t = init->GetType()->AsRecordType(); + + for ( int j = 0; j < t->NumFields(); ++j ) { + const char* field_name = t->FieldName(j); + int field = lhs->FieldOffset(field_name); + + if ( field >= 0 && same_type(lhs->GetFieldType(field), t->GetFieldType(j)) ) { + auto fe_lhs = make_intrusive(record, field_name); + auto fe_rhs = make_intrusive(IntrusivePtr{NewRef{}, init}, field_name); + Append(get_assign_expr(std::move(fe_lhs), std::move(fe_rhs), is_init)); + } + } + } + + else if ( init->Tag() == EXPR_FIELD_ASSIGN ) { + FieldAssignExpr* rf = (FieldAssignExpr*)init; + rf->Ref(); + + const char* field_name = ""; // rf->FieldName(); + if ( lhs->HasField(field_name) ) { + auto fe_lhs = make_intrusive(record, field_name); + ExprPtr fe_rhs = {NewRef{}, rf->Op()}; + Append(get_assign_expr(std::move(fe_lhs), std::move(fe_rhs), is_init)); + } + else { + std::string s = "No such field '"; + s += field_name; + s += "'"; + init_list->SetError(s.c_str()); + } + } + + else { + init_list->SetError("bad record initializer"); + return; + } + } +} + +CastExpr::CastExpr(ExprPtr arg_op, TypePtr t) : UnaryExpr(EXPR_CAST, std::move(arg_op)) { + auto stype = Op()->GetType(); + + SetType(std::move(t)); + + if ( ! can_cast_value_to_type(stype.get(), GetType().get()) ) + ExprError("cast not supported"); +} + +ValPtr CastExpr::Fold(Val* v) const { + std::string error; + auto res = cast_value({NewRef{}, v}, GetType(), error); + + if ( ! res ) + RuntimeError(error.c_str()); + + return res; +} + +ValPtr cast_value(ValPtr v, const TypePtr& t, std::string& error) { + auto nv = cast_value_to_type(v.get(), t.get()); + + if ( nv ) + return nv; + + ODesc d; + + d.Add("invalid cast of value with type '"); + v->GetType()->Describe(&d); + d.Add("' to type '"); + t->Describe(&d); + d.Add("'"); + + if ( same_type(v->GetType(), Broker::detail::DataVal::ScriptDataType()) && ! v->AsRecordVal()->HasField(0) ) + d.Add(" (nil $data field)"); + + error = d.Description(); + return nullptr; +} + +void CastExpr::ExprDescribe(ODesc* d) const { + Op()->Describe(d); + d->Add(" as "); + GetType()->Describe(d); +} + +IsExpr::IsExpr(ExprPtr arg_op, TypePtr arg_t) : UnaryExpr(EXPR_IS, std::move(arg_op)), t(std::move(arg_t)) { + SetType(base_type(TYPE_BOOL)); +} + +ValPtr IsExpr::Fold(Val* v) const { + if ( IsError() ) + return nullptr; + + return val_mgr->Bool(can_cast_value_to_type(v, t.get())); +} + +void IsExpr::ExprDescribe(ODesc* d) const { + Op()->Describe(d); + d->Add(" is "); + t->Describe(d); +} + +ExprPtr get_assign_expr(ExprPtr op1, ExprPtr op2, bool is_init) { + if ( op1->GetType()->Tag() == TYPE_RECORD && op2->GetType()->Tag() == TYPE_LIST ) + return make_intrusive(std::move(op1), std::move(op2), is_init); + + else if ( op1->Tag() == EXPR_INDEX && op1->AsIndexExpr()->IsSlice() ) + return make_intrusive(std::move(op1), std::move(op2), is_init); + + else + return make_intrusive(std::move(op1), std::move(op2), is_init); +} + +ExprPtr check_and_promote_expr(ExprPtr e, TypePtr t) { + const auto& et = e->GetType(); + TypeTag e_tag = et->Tag(); + TypeTag t_tag = t->Tag(); + + if ( t_tag == TYPE_ANY ) { + if ( e_tag != TYPE_ANY ) + return make_intrusive(e); + + return e; + } + + if ( e_tag == TYPE_ANY ) + return make_intrusive(e, t); + + if ( EitherArithmetic(t_tag, e_tag) ) { + if ( e_tag == t_tag ) + return e; + + if ( ! BothArithmetic(t_tag, e_tag) ) { + t->Error("arithmetic mixed with non-arithmetic", e.get()); + return nullptr; + } + + TypeTag mt = max_type(t_tag, e_tag); + if ( mt != t_tag ) { + t->Error("over-promotion of arithmetic value", e.get()); + return nullptr; + } + + return make_intrusive(e, t_tag); + } + + if ( t->Tag() == TYPE_RECORD && et->Tag() == TYPE_RECORD ) { + RecordType* t_r = t->AsRecordType(); + RecordType* et_r = et->AsRecordType(); + + if ( same_type(t, et) ) + return e; + + if ( record_promotion_compatible(t_r, et_r) ) + return make_intrusive(e, IntrusivePtr{NewRef{}, t_r}); + + t->Error("incompatible record types", e.get()); + return nullptr; + } + + if ( ! same_type(t, et) ) { + if ( t->Tag() == TYPE_TABLE && et->Tag() == TYPE_TABLE && et->AsTableType()->IsUnspecifiedTable() ) { + if ( e->Tag() == EXPR_TABLE_CONSTRUCTOR ) { + auto& attrs = cast_intrusive(e)->GetAttrs(); + zeek::detail::AttrPtr def = Attr::nil; + + // Check for &default or &default_insert expressions + // and use it for type checking against t. + if ( attrs ) { + def = attrs->Find(ATTR_DEFAULT); + if ( ! def ) + def = attrs->Find(ATTR_DEFAULT_INSERT); + } + + if ( def ) { + std::string err_msg; + if ( ! check_default_attr(def.get(), t, false, false, err_msg) ) { + if ( ! err_msg.empty() ) + t->Error(err_msg.c_str(), e.get()); + return nullptr; + } + } + } + + return make_intrusive(e, IntrusivePtr{NewRef{}, t->AsTableType()}, false); + } + + if ( t->Tag() == TYPE_VECTOR && et->Tag() == TYPE_VECTOR && et->AsVectorType()->IsUnspecifiedVector() ) + return make_intrusive(e, IntrusivePtr{NewRef{}, t->AsVectorType()}); + + if ( t->Tag() != TYPE_ERROR && et->Tag() != TYPE_ERROR ) + t->Error("type clash", e.get()); + + return nullptr; + } + + return e; +} + +bool check_and_promote_exprs(ListExpr* const elements, const TypeListPtr& types) { + ExprPList& el = elements->Exprs(); + const auto& tl = types->GetTypes(); + + if ( tl.size() == 1 && tl[0]->Tag() == TYPE_ANY ) + return true; + + if ( el.length() != static_cast(tl.size()) ) { + types->Error("indexing mismatch", elements); + return false; + } + + loop_over_list(el, i) { + ExprPtr e = {NewRef{}, el[i]}; + auto promoted_e = check_and_promote_expr(e, tl[i]); + + if ( ! promoted_e ) { + e->Error("type mismatch", tl[i].get()); + return false; + } + + if ( promoted_e != e ) + Unref(el.replace(i, promoted_e.release())); + } + + return true; +} + +bool check_and_promote_args(ListExpr* const args, const RecordType* types) { + ExprPList& el = args->Exprs(); + int ntypes = types->NumFields(); + + // give variadic BIFs automatic pass + if ( ntypes == 1 && types->FieldDecl(0)->type->Tag() == TYPE_ANY ) + return true; + + if ( el.length() < ntypes ) { + std::vector def_elements; + + // Start from rightmost parameter, work backward to fill in missing + // arguments using &default expressions. + for ( int i = ntypes - 1; i >= el.length(); --i ) { + auto td = types->FieldDecl(i); + const auto& def_attr = td->attrs ? td->attrs->Find(ATTR_DEFAULT).get() : nullptr; + + if ( ! def_attr ) { + types->Error("parameter mismatch", args); + return false; + } + + // Don't use the default expression directly, as + // doing so will wind up sharing its code across + // different invocations that use the default + // argument. That works okay for the interpreter, + // but if we transform the code we want that done + // separately for each instance, rather than + // one instance inheriting the transformed version + // from another. + const auto& e = def_attr->GetExpr(); + def_elements.emplace_back(e->Duplicate()); + } + + auto ne = def_elements.size(); + while ( ne ) + el.push_back(def_elements[--ne].release()); + } + + auto tl = make_intrusive(); + + for ( int i = 0; i < types->NumFields(); ++i ) + tl->Append(types->GetFieldType(i)); + + int rval = check_and_promote_exprs(args, tl); + + return rval; +} + +bool check_and_promote_exprs_to_type(ListExpr* const elements, TypePtr t) { + ExprPList& el = elements->Exprs(); + + if ( t->Tag() == TYPE_ANY ) + return true; + + loop_over_list(el, i) { + ExprPtr e = {NewRef{}, el[i]}; + auto promoted_e = check_and_promote_expr(e, t); + + if ( ! promoted_e ) { + e->Error("type mismatch", t.get()); + return false; + } + + if ( promoted_e != e ) + Unref(el.replace(i, promoted_e.release())); + } + + return true; +} + +std::optional> eval_list(Frame* f, const ListExpr* l) { + const ExprPList& e = l->Exprs(); + auto rval = std::make_optional>(); + rval->reserve(e.length()); + + for ( const auto& expr : e ) { + auto ev = expr->Eval(f); + + if ( ! ev ) + return {}; + + rval->emplace_back(std::move(ev)); + } + + return rval; +} + +bool expr_greater(const Expr* e1, const Expr* e2) { return e1->Tag() > e2->Tag(); } + +} // namespace zeek::detail diff --git a/src/Expr.h b/src/Expr.h index 2dde29e52c..c83fbcd843 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -18,12 +18,11 @@ #include "zeek/ZeekArgs.h" #include "zeek/ZeekList.h" -namespace zeek - { -template class IntrusivePtr; +namespace zeek { +template +class IntrusivePtr; -namespace detail - { +namespace detail { class Frame; class Scope; @@ -34,81 +33,80 @@ using ScopePtr = IntrusivePtr; using ScriptFuncPtr = IntrusivePtr; using FunctionIngredientsPtr = std::shared_ptr; -enum ExprTag : int - { - EXPR_ANY = -1, - EXPR_NAME, - EXPR_CONST, - EXPR_CLONE, - EXPR_INCR, - EXPR_DECR, - EXPR_NOT, - EXPR_COMPLEMENT, - EXPR_POSITIVE, - EXPR_NEGATE, - EXPR_ADD, - EXPR_SUB, - EXPR_ADD_TO, - EXPR_REMOVE_FROM, - EXPR_TIMES, - EXPR_DIVIDE, - EXPR_MASK, - EXPR_MOD, - EXPR_AND, - EXPR_OR, - EXPR_XOR, - EXPR_LSHIFT, - EXPR_RSHIFT, - EXPR_AND_AND, - EXPR_OR_OR, - EXPR_LT, - EXPR_LE, - EXPR_EQ, - EXPR_NE, - EXPR_GE, - EXPR_GT, - EXPR_COND, - EXPR_REF, - EXPR_ASSIGN, - EXPR_INDEX, - EXPR_FIELD, - EXPR_HAS_FIELD, - EXPR_RECORD_CONSTRUCTOR, - EXPR_TABLE_CONSTRUCTOR, - EXPR_SET_CONSTRUCTOR, - EXPR_VECTOR_CONSTRUCTOR, - EXPR_FIELD_ASSIGN, - EXPR_IN, - EXPR_LIST, - EXPR_CALL, - EXPR_LAMBDA, - EXPR_EVENT, - EXPR_SCHEDULE, - EXPR_ARITH_COERCE, - EXPR_RECORD_COERCE, - EXPR_TABLE_COERCE, - EXPR_VECTOR_COERCE, - EXPR_SIZE, - EXPR_CAST, - EXPR_IS, - EXPR_INDEX_SLICE_ASSIGN, - EXPR_INLINE, +enum ExprTag : int { + EXPR_ANY = -1, + EXPR_NAME, + EXPR_CONST, + EXPR_CLONE, + EXPR_INCR, + EXPR_DECR, + EXPR_NOT, + EXPR_COMPLEMENT, + EXPR_POSITIVE, + EXPR_NEGATE, + EXPR_ADD, + EXPR_SUB, + EXPR_ADD_TO, + EXPR_REMOVE_FROM, + EXPR_TIMES, + EXPR_DIVIDE, + EXPR_MASK, + EXPR_MOD, + EXPR_AND, + EXPR_OR, + EXPR_XOR, + EXPR_LSHIFT, + EXPR_RSHIFT, + EXPR_AND_AND, + EXPR_OR_OR, + EXPR_LT, + EXPR_LE, + EXPR_EQ, + EXPR_NE, + EXPR_GE, + EXPR_GT, + EXPR_COND, + EXPR_REF, + EXPR_ASSIGN, + EXPR_INDEX, + EXPR_FIELD, + EXPR_HAS_FIELD, + EXPR_RECORD_CONSTRUCTOR, + EXPR_TABLE_CONSTRUCTOR, + EXPR_SET_CONSTRUCTOR, + EXPR_VECTOR_CONSTRUCTOR, + EXPR_FIELD_ASSIGN, + EXPR_IN, + EXPR_LIST, + EXPR_CALL, + EXPR_LAMBDA, + EXPR_EVENT, + EXPR_SCHEDULE, + EXPR_ARITH_COERCE, + EXPR_RECORD_COERCE, + EXPR_TABLE_COERCE, + EXPR_VECTOR_COERCE, + EXPR_SIZE, + EXPR_CAST, + EXPR_IS, + EXPR_INDEX_SLICE_ASSIGN, + EXPR_INLINE, - // The following types of expressions are only created for - // ASTs transformed to reduced form; they aren't germane for - // ASTs produced by parsing .zeek script files. - EXPR_INDEX_ASSIGN, - EXPR_FIELD_LHS_ASSIGN, - EXPR_APPEND_TO, - EXPR_TO_ANY_COERCE, - EXPR_FROM_ANY_COERCE, - EXPR_FROM_ANY_VEC_COERCE, - EXPR_ANY_INDEX, + // The following types of expressions are only created for + // ASTs transformed to reduced form; they aren't germane for + // ASTs produced by parsing .zeek script files. + EXPR_INDEX_ASSIGN, + EXPR_FIELD_LHS_ASSIGN, + EXPR_APPEND_TO, + EXPR_TO_ANY_COERCE, + EXPR_FROM_ANY_COERCE, + EXPR_FROM_ANY_VEC_COERCE, + EXPR_ANY_INDEX, - EXPR_NOP, + EXPR_NOP, #define NUM_EXPRS (int(EXPR_NOP) + 1) - }; +}; extern const char* expr_name(ExprTag t); @@ -150,909 +148,871 @@ using StmtPtr = IntrusivePtr; class ExprOptInfo; -class Expr : public Obj - { +class Expr : public Obj { public: - const TypePtr& GetType() const { return type; } + const TypePtr& GetType() const { return type; } - template IntrusivePtr GetType() const { return cast_intrusive(type); } + template + IntrusivePtr GetType() const { + return cast_intrusive(type); + } - ExprTag Tag() const { return tag; } + ExprTag Tag() const { return tag; } - Expr* Ref() - { - zeek::Ref(this); - return this; - } - ExprPtr ThisPtr() { return {NewRef{}, this}; } + Expr* Ref() { + zeek::Ref(this); + return this; + } + ExprPtr ThisPtr() { return {NewRef{}, this}; } - // Evaluates the expression and returns a corresponding Val*, - // or nil if the expression's value isn't fixed. - virtual ValPtr Eval(Frame* f) const = 0; + // Evaluates the expression and returns a corresponding Val*, + // or nil if the expression's value isn't fixed. + virtual ValPtr Eval(Frame* f) const = 0; - // Assign to the given value, if appropriate. - virtual void Assign(Frame* f, ValPtr v); + // Assign to the given value, if appropriate. + virtual void Assign(Frame* f, ValPtr v); - // Returns the type corresponding to this expression interpreted - // as an initialization. Returns nil if the initialization is illegal. - virtual TypePtr InitType() const; + // Returns the type corresponding to this expression interpreted + // as an initialization. Returns nil if the initialization is illegal. + virtual TypePtr InitType() const; - // Returns true if this expression, interpreted as an initialization, - // constitutes a record element, false otherwise. If the TypeDecl* - // is non-nil and the expression is a record element, fills in the - // TypeDecl with a description of the element. - virtual bool IsRecordElement(TypeDecl* td) const; + // Returns true if this expression, interpreted as an initialization, + // constitutes a record element, false otherwise. If the TypeDecl* + // is non-nil and the expression is a record element, fills in the + // TypeDecl with a description of the element. + virtual bool IsRecordElement(TypeDecl* td) const; - // True if the expression has no side effects, false otherwise. - virtual bool IsPure() const { return true; } + // True if the expression has no side effects, false otherwise. + virtual bool IsPure() const { return true; } - // True if the expression is a constant, false otherwise. - bool IsConst() const { return tag == EXPR_CONST; } + // True if the expression is a constant, false otherwise. + bool IsConst() const { return tag == EXPR_CONST; } - // True if the expression is in error (to alleviate error propagation). - bool IsError() const; + // True if the expression is in error (to alleviate error propagation). + bool IsError() const; - // Mark expression as in error. - void SetError(); - void SetError(const char* msg); + // Mark expression as in error. + void SetError(); + void SetError(const char* msg); - // Returns the expression's constant value, or complains - // if it's not a constant. - inline Val* ExprVal() const; + // Returns the expression's constant value, or complains + // if it's not a constant. + inline Val* ExprVal() const; - // True if the expression is a constant zero, false otherwise. - bool IsZero() const; + // True if the expression is a constant zero, false otherwise. + bool IsZero() const; - // True if the expression is a constant one, false otherwise. - bool IsOne() const; + // True if the expression is a constant one, false otherwise. + bool IsOne() const; - // True if the expression supports the "add" or "delete" operations, - // false otherwise. - virtual bool CanAdd() const; - virtual bool CanDel() const; + // True if the expression supports the "add" or "delete" operations, + // false otherwise. + virtual bool CanAdd() const; + virtual bool CanDel() const; - virtual void Add(Frame* f); // perform add operation - virtual void Delete(Frame* f); // perform delete operation + virtual void Add(Frame* f); // perform add operation + virtual void Delete(Frame* f); // perform delete operation - // Return the expression converted to L-value form. If expr - // cannot be used as an L-value, reports an error and returns - // the current value of expr (this is the default method). - virtual ExprPtr MakeLvalue(); + // Return the expression converted to L-value form. If expr + // cannot be used as an L-value, reports an error and returns + // the current value of expr (this is the default method). + virtual ExprPtr MakeLvalue(); - // Invert the sense of the operation. Returns true if the expression - // was invertible (currently only true for relational/equality - // expressions), false otherwise. - virtual bool InvertSense(); + // Invert the sense of the operation. Returns true if the expression + // was invertible (currently only true for relational/equality + // expressions), false otherwise. + virtual bool InvertSense(); - // Marks the expression as one requiring (or at least appearing - // with) parentheses. Used for pretty-printing. - void MarkParen() { paren = true; } - bool IsParen() const { return paren; } + // Marks the expression as one requiring (or at least appearing + // with) parentheses. Used for pretty-printing. + void MarkParen() { paren = true; } + bool IsParen() const { return paren; } -#define ZEEK_EXPR_ACCESSOR_DECLS(ctype) \ - const ctype* As##ctype() const; \ - ctype* As##ctype(); \ - IntrusivePtr As##ctype##Ptr(); +#define ZEEK_EXPR_ACCESSOR_DECLS(ctype) \ + const ctype* As##ctype() const; \ + ctype* As##ctype(); \ + IntrusivePtr As##ctype##Ptr(); - ZEEK_EXPR_ACCESSOR_DECLS(AddToExpr) - ZEEK_EXPR_ACCESSOR_DECLS(AnyIndexExpr) - ZEEK_EXPR_ACCESSOR_DECLS(AssignExpr) - ZEEK_EXPR_ACCESSOR_DECLS(CallExpr) - ZEEK_EXPR_ACCESSOR_DECLS(ConstExpr) - ZEEK_EXPR_ACCESSOR_DECLS(EventExpr) - ZEEK_EXPR_ACCESSOR_DECLS(FieldAssignExpr) - ZEEK_EXPR_ACCESSOR_DECLS(FieldExpr) - ZEEK_EXPR_ACCESSOR_DECLS(FieldLHSAssignExpr) - ZEEK_EXPR_ACCESSOR_DECLS(ForExpr) - ZEEK_EXPR_ACCESSOR_DECLS(HasFieldExpr) - ZEEK_EXPR_ACCESSOR_DECLS(IndexAssignExpr) - ZEEK_EXPR_ACCESSOR_DECLS(IndexExpr) - ZEEK_EXPR_ACCESSOR_DECLS(InlineExpr) - ZEEK_EXPR_ACCESSOR_DECLS(IsExpr) - ZEEK_EXPR_ACCESSOR_DECLS(LambdaExpr) - ZEEK_EXPR_ACCESSOR_DECLS(ListExpr) - ZEEK_EXPR_ACCESSOR_DECLS(NameExpr) - ZEEK_EXPR_ACCESSOR_DECLS(RecordCoerceExpr) - ZEEK_EXPR_ACCESSOR_DECLS(RecordConstructorExpr) - ZEEK_EXPR_ACCESSOR_DECLS(RefExpr) - ZEEK_EXPR_ACCESSOR_DECLS(SetConstructorExpr) - ZEEK_EXPR_ACCESSOR_DECLS(TableConstructorExpr) + ZEEK_EXPR_ACCESSOR_DECLS(AddToExpr) + ZEEK_EXPR_ACCESSOR_DECLS(AnyIndexExpr) + ZEEK_EXPR_ACCESSOR_DECLS(AssignExpr) + ZEEK_EXPR_ACCESSOR_DECLS(CallExpr) + ZEEK_EXPR_ACCESSOR_DECLS(ConstExpr) + ZEEK_EXPR_ACCESSOR_DECLS(EventExpr) + ZEEK_EXPR_ACCESSOR_DECLS(FieldAssignExpr) + ZEEK_EXPR_ACCESSOR_DECLS(FieldExpr) + ZEEK_EXPR_ACCESSOR_DECLS(FieldLHSAssignExpr) + ZEEK_EXPR_ACCESSOR_DECLS(ForExpr) + ZEEK_EXPR_ACCESSOR_DECLS(HasFieldExpr) + ZEEK_EXPR_ACCESSOR_DECLS(IndexAssignExpr) + ZEEK_EXPR_ACCESSOR_DECLS(IndexExpr) + ZEEK_EXPR_ACCESSOR_DECLS(InlineExpr) + ZEEK_EXPR_ACCESSOR_DECLS(IsExpr) + ZEEK_EXPR_ACCESSOR_DECLS(LambdaExpr) + ZEEK_EXPR_ACCESSOR_DECLS(ListExpr) + ZEEK_EXPR_ACCESSOR_DECLS(NameExpr) + ZEEK_EXPR_ACCESSOR_DECLS(RecordCoerceExpr) + ZEEK_EXPR_ACCESSOR_DECLS(RecordConstructorExpr) + ZEEK_EXPR_ACCESSOR_DECLS(RefExpr) + ZEEK_EXPR_ACCESSOR_DECLS(SetConstructorExpr) + ZEEK_EXPR_ACCESSOR_DECLS(TableConstructorExpr) - void Describe(ODesc* d) const override final; + void Describe(ODesc* d) const override final; - virtual TraversalCode Traverse(TraversalCallback* cb) const = 0; + virtual TraversalCode Traverse(TraversalCallback* cb) const = 0; - // Returns a duplicate of the expression. - virtual ExprPtr Duplicate() = 0; + // Returns a duplicate of the expression. + virtual ExprPtr Duplicate() = 0; - // Recursively traverses the AST to inline eligible function calls. - virtual ExprPtr Inline(Inliner* inl) { return ThisPtr(); } + // Recursively traverses the AST to inline eligible function calls. + virtual ExprPtr Inline(Inliner* inl) { return ThisPtr(); } - // True if the expression can serve as an operand to a reduced - // expression. - bool IsSingleton(Reducer* r) const - { - return (tag == EXPR_NAME && IsReduced(r)) || tag == EXPR_CONST; - } + // True if the expression can serve as an operand to a reduced + // expression. + bool IsSingleton(Reducer* r) const { return (tag == EXPR_NAME && IsReduced(r)) || tag == EXPR_CONST; } - // True if the expression has no side effects, false otherwise. - virtual bool HasNoSideEffects() const { return IsPure(); } + // True if the expression has no side effects, false otherwise. + virtual bool HasNoSideEffects() const { return IsPure(); } - // True if the expression is in fully reduced form: a singleton - // or an assignment to an operator with singleton operands. - virtual bool IsReduced(Reducer* c) const; + // True if the expression is in fully reduced form: a singleton + // or an assignment to an operator with singleton operands. + virtual bool IsReduced(Reducer* c) const; - // True if the expression's operands are singletons. - virtual bool HasReducedOps(Reducer* c) const; + // True if the expression's operands are singletons. + virtual bool HasReducedOps(Reducer* c) const; - // True if (a) the expression has at least one operand, and (b) all - // of its operands are constant. - bool HasConstantOps() const - { - return GetOp1() && GetOp1()->IsConst() && - (! GetOp2() || (GetOp2()->IsConst() && (! GetOp3() || GetOp3()->IsConst()))); - } + // True if (a) the expression has at least one operand, and (b) all + // of its operands are constant. + bool HasConstantOps() const { + return GetOp1() && GetOp1()->IsConst() && + (! GetOp2() || (GetOp2()->IsConst() && (! GetOp3() || GetOp3()->IsConst()))); + } - // True if the expression is reduced to a form that can be - // used in a conditional. - bool IsReducedConditional(Reducer* c) const; + // True if the expression is reduced to a form that can be + // used in a conditional. + bool IsReducedConditional(Reducer* c) const; - // True if the expression is reduced to a form that can be - // used in a field assignment. - bool IsReducedFieldAssignment(Reducer* c) const; + // True if the expression is reduced to a form that can be + // used in a field assignment. + bool IsReducedFieldAssignment(Reducer* c) const; - // True if this expression can be the RHS for a field assignment. - bool IsFieldAssignable(const Expr* e) const; + // True if this expression can be the RHS for a field assignment. + bool IsFieldAssignable(const Expr* e) const; - // True if the expression will transform to one of another AST node - // (perhaps of the same type) upon reduction, for non-constant - // operands. "Transform" means something beyond assignment to a - // temporary. Necessary so that we know to fully reduce such - // expressions if they're the RHS of an assignment. - virtual bool WillTransform(Reducer* c) const { return false; } + // True if the expression will transform to one of another AST node + // (perhaps of the same type) upon reduction, for non-constant + // operands. "Transform" means something beyond assignment to a + // temporary. Necessary so that we know to fully reduce such + // expressions if they're the RHS of an assignment. + virtual bool WillTransform(Reducer* c) const { return false; } - // The same, but for the expression when used in a conditional context. - virtual bool WillTransformInConditional(Reducer* c) const { return false; } + // The same, but for the expression when used in a conditional context. + virtual bool WillTransformInConditional(Reducer* c) const { return false; } - // Returns the current expression transformed into "new_me". - ExprPtr TransformMe(ExprPtr new_me, Reducer* c, StmtPtr& red_stmt); + // Returns the current expression transformed into "new_me". + ExprPtr TransformMe(ExprPtr new_me, Reducer* c, StmtPtr& red_stmt); - // Returns a set of predecessor statements in red_stmt (which might - // be nil if no reduction necessary), and the reduced version of - // the expression, suitable for replacing previous uses. The - // second version always yields a singleton suitable for use - // as an operand. The first version does this too except - // for assignment statements; thus, its form is not guarantee - // suitable for use as an operand. - virtual ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt); - virtual ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) { return Reduce(c, red_stmt); } + // Returns a set of predecessor statements in red_stmt (which might + // be nil if no reduction necessary), and the reduced version of + // the expression, suitable for replacing previous uses. The + // second version always yields a singleton suitable for use + // as an operand. The first version does this too except + // for assignment statements; thus, its form is not guarantee + // suitable for use as an operand. + virtual ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt); + virtual ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) { return Reduce(c, red_stmt); } - // Reduces the expression to one whose operands are singletons. - // Returns a predecessor statement (which might be a StmtList), if any. - virtual StmtPtr ReduceToSingletons(Reducer* c); + // Reduces the expression to one whose operands are singletons. + // Returns a predecessor statement (which might be a StmtList), if any. + virtual StmtPtr ReduceToSingletons(Reducer* c); - // Reduces the expression to one that can appear as a conditional. - ExprPtr ReduceToConditional(Reducer* c, StmtPtr& red_stmt); + // Reduces the expression to one that can appear as a conditional. + ExprPtr ReduceToConditional(Reducer* c, StmtPtr& red_stmt); - // Reduces the expression to one that can appear as a field - // assignment. - ExprPtr ReduceToFieldAssignment(Reducer* c, StmtPtr& red_stmt); + // Reduces the expression to one that can appear as a field + // assignment. + ExprPtr ReduceToFieldAssignment(Reducer* c, StmtPtr& red_stmt); - // Helper function for factoring out complexities related to - // index-based assignment. - void AssignToIndex(ValPtr v1, ValPtr v2, ValPtr v3) const; + // Helper function for factoring out complexities related to + // index-based assignment. + void AssignToIndex(ValPtr v1, ValPtr v2, ValPtr v3) const; - // Returns a new expression corresponding to a temporary - // that's been assigned to the given expression via red_stmt. - ExprPtr AssignToTemporary(ExprPtr e, Reducer* c, StmtPtr& red_stmt); - // Same but for this expression. - ExprPtr AssignToTemporary(Reducer* c, StmtPtr& red_stmt) - { - return AssignToTemporary(ThisPtr(), c, red_stmt); - } + // Returns a new expression corresponding to a temporary + // that's been assigned to the given expression via red_stmt. + ExprPtr AssignToTemporary(ExprPtr e, Reducer* c, StmtPtr& red_stmt); + // Same but for this expression. + ExprPtr AssignToTemporary(Reducer* c, StmtPtr& red_stmt) { return AssignToTemporary(ThisPtr(), c, red_stmt); } - // If the expression always evaluates to the same value, returns - // that value. Otherwise, returns nullptr. - virtual ValPtr FoldVal() const { return nullptr; } + // If the expression always evaluates to the same value, returns + // that value. Otherwise, returns nullptr. + virtual ValPtr FoldVal() const { return nullptr; } - // Returns a Val or a constant Expr corresponding to zero. - ValPtr MakeZero(TypeTag t) const; - ConstExprPtr MakeZeroExpr(TypeTag t) const; + // Returns a Val or a constant Expr corresponding to zero. + ValPtr MakeZero(TypeTag t) const; + ConstExprPtr MakeZeroExpr(TypeTag t) const; - // Returns the expression's operands, or nil if it doesn't - // have the given operand. - virtual ExprPtr GetOp1() const; - virtual ExprPtr GetOp2() const; - virtual ExprPtr GetOp3() const; + // Returns the expression's operands, or nil if it doesn't + // have the given operand. + virtual ExprPtr GetOp1() const; + virtual ExprPtr GetOp2() const; + virtual ExprPtr GetOp3() const; - // Sets the operands to new values. - virtual void SetOp1(ExprPtr new_op); - virtual void SetOp2(ExprPtr new_op); - virtual void SetOp3(ExprPtr new_op); + // Sets the operands to new values. + virtual void SetOp1(ExprPtr new_op); + virtual void SetOp2(ExprPtr new_op); + virtual void SetOp3(ExprPtr new_op); - // Helper function to reduce boring code runs. - StmtPtr MergeStmts(StmtPtr s1, StmtPtr s2, StmtPtr s3 = nullptr) const; + // Helper function to reduce boring code runs. + StmtPtr MergeStmts(StmtPtr s1, StmtPtr s2, StmtPtr s3 = nullptr) const; - // Access to the original expression from which this one is derived, - // or this one if we don't have an original. Returns a bare pointer - // rather than an ExprPtr to emphasize that the access is read-only. - const Expr* Original() const { return original ? original->Original() : this; } + // Access to the original expression from which this one is derived, + // or this one if we don't have an original. Returns a bare pointer + // rather than an ExprPtr to emphasize that the access is read-only. + const Expr* Original() const { return original ? original->Original() : this; } - // Designate the given Expr node as the original for this one. - void SetOriginal(ExprPtr _orig) - { - if ( ! original ) - original = std::move(_orig); - } + // Designate the given Expr node as the original for this one. + void SetOriginal(ExprPtr _orig) { + if ( ! original ) + original = std::move(_orig); + } - // A convenience function for taking a newly-created Expr, - // making it point to us as the successor, and returning it. - // - // Takes an Expr* rather than a ExprPtr to de-clutter the calling - // code, which is always passing in "new XyzExpr(...)". This - // call, as a convenient side effect, transforms that bare pointer - // into an ExprPtr. - virtual ExprPtr SetSucc(Expr* succ) - { - succ->SetOriginal(ThisPtr()); - if ( IsParen() ) - succ->MarkParen(); - return {AdoptRef{}, succ}; - } + // A convenience function for taking a newly-created Expr, + // making it point to us as the successor, and returning it. + // + // Takes an Expr* rather than a ExprPtr to de-clutter the calling + // code, which is always passing in "new XyzExpr(...)". This + // call, as a convenient side effect, transforms that bare pointer + // into an ExprPtr. + virtual ExprPtr SetSucc(Expr* succ) { + succ->SetOriginal(ThisPtr()); + if ( IsParen() ) + succ->MarkParen(); + return {AdoptRef{}, succ}; + } - const detail::Location* GetLocationInfo() const override - { - if ( original ) - return original->GetLocationInfo(); - else - return Obj::GetLocationInfo(); - } + const detail::Location* GetLocationInfo() const override { + if ( original ) + return original->GetLocationInfo(); + else + return Obj::GetLocationInfo(); + } - // Access script optimization information associated with - // this statement. - ExprOptInfo* GetOptInfo() const { return opt_info; } + // Access script optimization information associated with + // this statement. + ExprOptInfo* GetOptInfo() const { return opt_info; } - // Returns the number of expressions created since the last reset. - static int GetNumExprs() { return num_exprs; } + // Returns the number of expressions created since the last reset. + static int GetNumExprs() { return num_exprs; } - // Clears the number of expressions created. - static void ResetNumExprs() { num_exprs = 0; } + // Clears the number of expressions created. + static void ResetNumExprs() { num_exprs = 0; } - ~Expr() override; + ~Expr() override; protected: - Expr() = default; - explicit Expr(ExprTag arg_tag); + Expr() = default; + explicit Expr(ExprTag arg_tag); - virtual void ExprDescribe(ODesc* d) const = 0; - void AddTag(ODesc* d) const; + virtual void ExprDescribe(ODesc* d) const = 0; + void AddTag(ODesc* d) const; - // Puts the expression in canonical form. - virtual void Canonicalize(); + // Puts the expression in canonical form. + virtual void Canonicalize(); - void SetType(TypePtr t); + void SetType(TypePtr t); - // Reports the given error and sets the expression's type to - // TYPE_ERROR. - void ExprError(const char msg[]); + // Reports the given error and sets the expression's type to + // TYPE_ERROR. + void ExprError(const char msg[]); - // These two functions both call Reporter::RuntimeError or Reporter::ExprRuntimeError, - // both of which are marked as [[noreturn]]. - [[noreturn]] void RuntimeError(const std::string& msg) const; - [[noreturn]] void RuntimeErrorWithCallStack(const std::string& msg) const; + // These two functions both call Reporter::RuntimeError or Reporter::ExprRuntimeError, + // both of which are marked as [[noreturn]]. + [[noreturn]] void RuntimeError(const std::string& msg) const; + [[noreturn]] void RuntimeErrorWithCallStack(const std::string& msg) const; - ExprTag tag; - bool paren; - TypePtr type; + ExprTag tag; + bool paren; + TypePtr type; - // The original expression from which this statement was - // derived, if any. Used as an aid for generating meaningful - // and correctly-localized error messages. - ExprPtr original = nullptr; + // The original expression from which this statement was + // derived, if any. Used as an aid for generating meaningful + // and correctly-localized error messages. + ExprPtr original = nullptr; - // Information associated with the Expr for purposes of - // script optimization. - ExprOptInfo* opt_info; + // Information associated with the Expr for purposes of + // script optimization. + ExprOptInfo* opt_info; - // Number of expressions created thus far. - static int num_exprs; - }; + // Number of expressions created thus far. + static int num_exprs; +}; -class NameExpr final : public Expr - { +class NameExpr final : public Expr { public: - explicit NameExpr(IDPtr id, bool const_init = false); + explicit NameExpr(IDPtr id, bool const_init = false); - ID* Id() const { return id.get(); } - const IDPtr& IdPtr() const; + ID* Id() const { return id.get(); } + const IDPtr& IdPtr() const; - ValPtr Eval(Frame* f) const override; - void Assign(Frame* f, ValPtr v) override; - ExprPtr MakeLvalue() override; + ValPtr Eval(Frame* f) const override; + void Assign(Frame* f, ValPtr v) override; + ExprPtr MakeLvalue() override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - bool HasNoSideEffects() const override { return true; } - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override { return IsReduced(c); } - bool WillTransform(Reducer* c) const override { return ! IsReduced(c); } - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ValPtr FoldVal() const override; + // Optimization-related: + ExprPtr Duplicate() override; + bool HasNoSideEffects() const override { return true; } + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override { return IsReduced(c); } + bool WillTransform(Reducer* c) const override { return ! IsReduced(c); } + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ValPtr FoldVal() const override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - // Returns true if our identifier is a global with a constant value - // that can be propagated; used for optimization. - bool FoldableGlobal() const; + // Returns true if our identifier is a global with a constant value + // that can be propagated; used for optimization. + bool FoldableGlobal() const; - IDPtr id; - bool in_const_init; - }; + IDPtr id; + bool in_const_init; +}; -class ConstExpr final : public Expr - { +class ConstExpr final : public Expr { public: - explicit ConstExpr(ValPtr val); + explicit ConstExpr(ValPtr val); - Val* Value() const { return val.get(); } - ValPtr ValuePtr() const { return val; } + Val* Value() const { return val.get(); } + ValPtr ValuePtr() const { return val; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - ValPtr FoldVal() const override { return val; } + // Optimization-related: + ExprPtr Duplicate() override; + ValPtr FoldVal() const override { return val; } protected: - void ExprDescribe(ODesc* d) const override; - ValPtr val; - }; + void ExprDescribe(ODesc* d) const override; + ValPtr val; +}; -class UnaryExpr : public Expr - { +class UnaryExpr : public Expr { public: - Expr* Op() const { return op.get(); } + Expr* Op() const { return op.get(); } - // UnaryExpr::Eval correctly handles vector types. Any child - // class that overrides Eval() should be modified to handle - // vectors correctly as necessary. - ValPtr Eval(Frame* f) const override; + // UnaryExpr::Eval correctly handles vector types. Any child + // class that overrides Eval() should be modified to handle + // vectors correctly as necessary. + ValPtr Eval(Frame* f) const override; - bool IsPure() const override; + bool IsPure() const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Inline(Inliner* inl) override; - bool HasNoSideEffects() const override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool HasNoSideEffects() const override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr GetOp1() const override final { return op; } - void SetOp1(ExprPtr _op) override final { op = std::move(_op); } + ExprPtr GetOp1() const override final { return op; } + void SetOp1(ExprPtr _op) override final { op = std::move(_op); } protected: - UnaryExpr(ExprTag arg_tag, ExprPtr arg_op); + UnaryExpr(ExprTag arg_tag, ExprPtr arg_op); - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - // Returns the expression folded using the given constant. - virtual ValPtr Fold(Val* v) const; + // Returns the expression folded using the given constant. + virtual ValPtr Fold(Val* v) const; - ExprPtr op; - }; + ExprPtr op; +}; -class BinaryExpr : public Expr - { +class BinaryExpr : public Expr { public: - Expr* Op1() const { return op1.get(); } - Expr* Op2() const { return op2.get(); } + Expr* Op1() const { return op1.get(); } + Expr* Op2() const { return op2.get(); } - bool IsPure() const override; + bool IsPure() const override; - // BinaryExpr::Eval correctly handles vector types. Any child - // class that overrides Eval() should be modified to handle - // vectors correctly as necessary. - ValPtr Eval(Frame* f) const override; + // BinaryExpr::Eval correctly handles vector types. Any child + // class that overrides Eval() should be modified to handle + // vectors correctly as necessary. + ValPtr Eval(Frame* f) const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Inline(Inliner* inl) override; - bool HasNoSideEffects() const override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool HasNoSideEffects() const override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr GetOp1() const override final { return op1; } - ExprPtr GetOp2() const override final { return op2; } + ExprPtr GetOp1() const override final { return op1; } + ExprPtr GetOp2() const override final { return op2; } - void SetOp1(ExprPtr _op) override final { op1 = std::move(_op); } - void SetOp2(ExprPtr _op) override final { op2 = std::move(_op); } + void SetOp1(ExprPtr _op) override final { op1 = std::move(_op); } + void SetOp2(ExprPtr _op) override final { op2 = std::move(_op); } protected: - BinaryExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) - : Expr(arg_tag), op1(std::move(arg_op1)), op2(std::move(arg_op2)) - { - if ( ! (op1 && op2) ) - return; - if ( op1->IsError() || op2->IsError() ) - SetError(); - } + BinaryExpr(ExprTag arg_tag, ExprPtr arg_op1, ExprPtr arg_op2) + : Expr(arg_tag), op1(std::move(arg_op1)), op2(std::move(arg_op2)) { + if ( ! (op1 && op2) ) + return; + if ( op1->IsError() || op2->IsError() ) + SetError(); + } - // Returns the expression folded using the given constants. - virtual ValPtr Fold(Val* v1, Val* v2) const; + // Returns the expression folded using the given constants. + virtual ValPtr Fold(Val* v1, Val* v2) const; - // Same for when the constants are strings. - virtual ValPtr StringFold(Val* v1, Val* v2) const; + // Same for when the constants are strings. + virtual ValPtr StringFold(Val* v1, Val* v2) const; - // Same for when the constants are patterns. - virtual ValPtr PatternFold(Val* v1, Val* v2) const; + // Same for when the constants are patterns. + virtual ValPtr PatternFold(Val* v1, Val* v2) const; - // Same for when the constants are sets. - virtual ValPtr SetFold(Val* v1, Val* v2) const; + // Same for when the constants are sets. + virtual ValPtr SetFold(Val* v1, Val* v2) const; - // Same for when the constants are tables. - virtual ValPtr TableFold(Val* v1, Val* v2) const; + // Same for when the constants are tables. + virtual ValPtr TableFold(Val* v1, Val* v2) const; - // Same for when the constants are addresses or subnets. - virtual ValPtr AddrFold(Val* v1, Val* v2) const; - virtual ValPtr SubNetFold(Val* v1, Val* v2) const; + // Same for when the constants are addresses or subnets. + virtual ValPtr AddrFold(Val* v1, Val* v2) const; + virtual ValPtr SubNetFold(Val* v1, Val* v2) const; - bool BothConst() const { return op1->IsConst() && op2->IsConst(); } + bool BothConst() const { return op1->IsConst() && op2->IsConst(); } - // Exchange op1 and op2. - void SwapOps(); + // Exchange op1 and op2. + void SwapOps(); - // Promote the operands to the given type tag, if necessary. - void PromoteOps(TypeTag t); + // Promote the operands to the given type tag, if necessary. + void PromoteOps(TypeTag t); - // Promote the expression to the given type tag (i.e., promote - // operands and also set expression's type). - void PromoteType(TypeTag t, bool is_vector); + // Promote the expression to the given type tag (i.e., promote + // operands and also set expression's type). + void PromoteType(TypeTag t, bool is_vector); - // Promote one of the operands to be "double" (if not already), - // to make it suitable for combining with the other "interval" - // operand, yielding an "interval" type. - void PromoteForInterval(ExprPtr& op); + // Promote one of the operands to be "double" (if not already), + // to make it suitable for combining with the other "interval" + // operand, yielding an "interval" type. + void PromoteForInterval(ExprPtr& op); - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - // For assignment operations (=, +=, -=) checks for a valid - // expression-list on the RHS (op2), potentially transforming - // op2 in the process. Returns true if the list is present - // and type-checks correctly, false otherwise. - bool CheckForRHSList(); + // For assignment operations (=, +=, -=) checks for a valid + // expression-list on the RHS (op2), potentially transforming + // op2 in the process. Returns true if the list is present + // and type-checks correctly, false otherwise. + bool CheckForRHSList(); - ExprPtr op1; - ExprPtr op2; - }; + ExprPtr op1; + ExprPtr op2; +}; -class CloneExpr final : public UnaryExpr - { +class CloneExpr final : public UnaryExpr { public: - explicit CloneExpr(ExprPtr op); - ValPtr Eval(Frame* f) const override; + explicit CloneExpr(ExprPtr op); + ValPtr Eval(Frame* f) const override; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class IncrExpr final : public UnaryExpr - { +class IncrExpr final : public UnaryExpr { public: - IncrExpr(ExprTag tag, ExprPtr op); + IncrExpr(ExprTag tag, ExprPtr op); - ValPtr Eval(Frame* f) const override; - ValPtr DoSingleEval(Frame* f, Val* v) const; - bool IsPure() const override { return false; } + ValPtr Eval(Frame* f) const override; + ValPtr DoSingleEval(Frame* f, Val* v) const; + bool IsPure() const override { return false; } - // Optimization-related: - ExprPtr Duplicate() override; - bool HasNoSideEffects() const override; - bool WillTransform(Reducer* c) const override { return true; } - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override { return false; } - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; - }; + // Optimization-related: + ExprPtr Duplicate() override; + bool HasNoSideEffects() const override; + bool WillTransform(Reducer* c) const override { return true; } + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override { return false; } + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; +}; -class ComplementExpr final : public UnaryExpr - { +class ComplementExpr final : public UnaryExpr { public: - explicit ComplementExpr(ExprPtr op); + explicit ComplementExpr(ExprPtr op); - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class NotExpr final : public UnaryExpr - { +class NotExpr final : public UnaryExpr { public: - explicit NotExpr(ExprPtr op); + explicit NotExpr(ExprPtr op); - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class PosExpr final : public UnaryExpr - { +class PosExpr final : public UnaryExpr { public: - explicit PosExpr(ExprPtr op); + explicit PosExpr(ExprPtr op); - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class NegExpr final : public UnaryExpr - { +class NegExpr final : public UnaryExpr { public: - explicit NegExpr(ExprPtr op); + explicit NegExpr(ExprPtr op); - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class SizeExpr final : public UnaryExpr - { +class SizeExpr final : public UnaryExpr { public: - explicit SizeExpr(ExprPtr op); - ValPtr Eval(Frame* f) const override; + explicit SizeExpr(ExprPtr op); + ValPtr Eval(Frame* f) const override; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class AddExpr final : public BinaryExpr - { +class AddExpr final : public BinaryExpr { public: - AddExpr(ExprPtr op1, ExprPtr op2); - void Canonicalize() override; + AddExpr(ExprPtr op1, ExprPtr op2); + void Canonicalize() override; - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2); - }; + ExprPtr BuildSub(const ExprPtr& op1, const ExprPtr& op2); +}; -class AddToExpr final : public BinaryExpr - { +class AddToExpr final : public BinaryExpr { public: - AddToExpr(ExprPtr op1, ExprPtr op2); - ValPtr Eval(Frame* f) const override; + AddToExpr(ExprPtr op1, ExprPtr op2); + ValPtr Eval(Frame* f) const override; - // Optimization-related: - bool IsPure() const override { return false; } - ExprPtr Duplicate() override; - bool HasReducedOps(Reducer* c) const override { return false; } - bool WillTransform(Reducer* c) const override { return true; } - bool IsReduced(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + bool IsPure() const override { return false; } + ExprPtr Duplicate() override; + bool HasReducedOps(Reducer* c) const override { return false; } + bool WillTransform(Reducer* c) const override { return true; } + bool IsReduced(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; private: - // Whether this operation is appending a single element to a vector. - bool is_vector_elem_append = false; - }; + // Whether this operation is appending a single element to a vector. + bool is_vector_elem_append = false; +}; -class RemoveFromExpr final : public BinaryExpr - { +class RemoveFromExpr final : public BinaryExpr { public: - bool IsPure() const override { return false; } - RemoveFromExpr(ExprPtr op1, ExprPtr op2); - ValPtr Eval(Frame* f) const override; + bool IsPure() const override { return false; } + RemoveFromExpr(ExprPtr op1, ExprPtr op2); + ValPtr Eval(Frame* f) const override; - // Optimization-related: - ExprPtr Duplicate() override; - bool HasReducedOps(Reducer* c) const override { return false; } - bool WillTransform(Reducer* c) const override { return true; } - bool IsReduced(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; - }; + // Optimization-related: + ExprPtr Duplicate() override; + bool HasReducedOps(Reducer* c) const override { return false; } + bool WillTransform(Reducer* c) const override { return true; } + bool IsReduced(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; +}; -class SubExpr final : public BinaryExpr - { +class SubExpr final : public BinaryExpr { public: - SubExpr(ExprPtr op1, ExprPtr op2); + SubExpr(ExprPtr op1, ExprPtr op2); - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - }; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; +}; -class TimesExpr final : public BinaryExpr - { +class TimesExpr final : public BinaryExpr { public: - TimesExpr(ExprPtr op1, ExprPtr op2); - void Canonicalize() override; + TimesExpr(ExprPtr op1, ExprPtr op2); + void Canonicalize() override; - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - }; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; +}; -class DivideExpr final : public BinaryExpr - { +class DivideExpr final : public BinaryExpr { public: - DivideExpr(ExprPtr op1, ExprPtr op2); + DivideExpr(ExprPtr op1, ExprPtr op2); - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - }; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; +}; -class MaskExpr final : public BinaryExpr - { +class MaskExpr final : public BinaryExpr { public: - MaskExpr(ExprPtr op1, ExprPtr op2); + MaskExpr(ExprPtr op1, ExprPtr op2); - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr AddrFold(Val* v1, Val* v2) const override; - }; + ValPtr AddrFold(Val* v1, Val* v2) const override; +}; -class ModExpr final : public BinaryExpr - { +class ModExpr final : public BinaryExpr { public: - ModExpr(ExprPtr op1, ExprPtr op2); + ModExpr(ExprPtr op1, ExprPtr op2); - // Optimization-related: - ExprPtr Duplicate() override; - }; + // Optimization-related: + ExprPtr Duplicate() override; +}; -class BoolExpr final : public BinaryExpr - { +class BoolExpr final : public BinaryExpr { public: - BoolExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); + BoolExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); - ValPtr Eval(Frame* f) const override; - ValPtr DoSingleEval(Frame* f, ValPtr v1, Expr* op2) const; + ValPtr Eval(Frame* f) const override; + ValPtr DoSingleEval(Frame* f, ValPtr v1, Expr* op2) const; - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - bool WillTransformInConditional(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + bool WillTransformInConditional(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - bool IsTrue(const ExprPtr& e) const; - bool IsFalse(const ExprPtr& e) const; - }; + bool IsTrue(const ExprPtr& e) const; + bool IsFalse(const ExprPtr& e) const; +}; -class BitExpr final : public BinaryExpr - { +class BitExpr final : public BinaryExpr { public: - BitExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); + BitExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - }; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; +}; -class EqExpr final : public BinaryExpr - { +class EqExpr final : public BinaryExpr { public: - EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); - void Canonicalize() override; + EqExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); + void Canonicalize() override; - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - bool InvertSense() override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool InvertSense() override; protected: - ValPtr Fold(Val* v1, Val* v2) const override; - }; + ValPtr Fold(Val* v1, Val* v2) const override; +}; -class RelExpr final : public BinaryExpr - { +class RelExpr final : public BinaryExpr { public: - RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); - void Canonicalize() override; + RelExpr(ExprTag tag, ExprPtr op1, ExprPtr op2); + void Canonicalize() override; - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - bool InvertSense() override; - }; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool InvertSense() override; +}; -class CondExpr final : public Expr - { +class CondExpr final : public Expr { public: - CondExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); + CondExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); - const Expr* Op1() const { return op1.get(); } - const Expr* Op2() const { return op2.get(); } - const Expr* Op3() const { return op3.get(); } + const Expr* Op1() const { return op1.get(); } + const Expr* Op2() const { return op2.get(); } + const Expr* Op3() const { return op3.get(); } - ValPtr Eval(Frame* f) const override; - bool IsPure() const override; + ValPtr Eval(Frame* f) const override; + bool IsPure() const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Duplicate() override; + ExprPtr Inline(Inliner* inl) override; - bool WillTransform(Reducer* c) const override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool WillTransform(Reducer* c) const override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; - ExprPtr GetOp1() const override final { return op1; } - ExprPtr GetOp2() const override final { return op2; } - ExprPtr GetOp3() const override final { return op3; } + ExprPtr GetOp1() const override final { return op1; } + ExprPtr GetOp2() const override final { return op2; } + ExprPtr GetOp3() const override final { return op3; } - void SetOp1(ExprPtr _op) override final { op1 = std::move(_op); } - void SetOp2(ExprPtr _op) override final { op2 = std::move(_op); } - void SetOp3(ExprPtr _op) override final { op3 = std::move(_op); } + void SetOp1(ExprPtr _op) override final { op1 = std::move(_op); } + void SetOp2(ExprPtr _op) override final { op2 = std::move(_op); } + void SetOp3(ExprPtr _op) override final { op3 = std::move(_op); } protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - ExprPtr op1; - ExprPtr op2; - ExprPtr op3; - }; + ExprPtr op1; + ExprPtr op2; + ExprPtr op3; +}; -class RefExpr final : public UnaryExpr - { +class RefExpr final : public UnaryExpr { public: - explicit RefExpr(ExprPtr op); + explicit RefExpr(ExprPtr op); - void Assign(Frame* f, ValPtr v) override; - ExprPtr MakeLvalue() override; + void Assign(Frame* f, ValPtr v) override; + ExprPtr MakeLvalue() override; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool WillTransform(Reducer* c) const override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - // Reduce to simplified LHS form, i.e., a reference to only a name. - StmtPtr ReduceToLHS(Reducer* c); - }; + // Reduce to simplified LHS form, i.e., a reference to only a name. + StmtPtr ReduceToLHS(Reducer* c); +}; -class AssignExpr : public BinaryExpr - { +class AssignExpr : public BinaryExpr { public: - // If val is given, evaluating this expression will always yield the val - // yet still perform the assignment. Used for triggers. - AssignExpr(ExprPtr op1, ExprPtr op2, bool is_init, ValPtr val = nullptr, - const AttributesPtr& attrs = nullptr, bool type_check = true); + // If val is given, evaluating this expression will always yield the val + // yet still perform the assignment. Used for triggers. + AssignExpr(ExprPtr op1, ExprPtr op2, bool is_init, ValPtr val = nullptr, const AttributesPtr& attrs = nullptr, + bool type_check = true); - ValPtr Eval(Frame* f) const override; - TypePtr InitType() const override; - bool IsRecordElement(TypeDecl* td) const override; - bool IsPure() const override { return false; } + ValPtr Eval(Frame* f) const override; + TypePtr InitType() const override; + bool IsRecordElement(TypeDecl* td) const override; + bool IsPure() const override { return false; } - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool HasNoSideEffects() const override; - bool WillTransform(Reducer* c) const override { return true; } - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; + bool HasNoSideEffects() const override; + bool WillTransform(Reducer* c) const override { return true; } + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; - // Whether this is an assignment to a temporary. - bool IsTemp() const { return is_temp; } - void SetIsTemp() { is_temp = true; } + // Whether this is an assignment to a temporary. + bool IsTemp() const { return is_temp; } + void SetIsTemp() { is_temp = true; } - // The following is a hack that's used in "when" expressions to support - // assignments to new locals, like "when ( (local l = foo()) && ...". - // These methods return the value to use when evaluating such - // assignments. That would normally be the RHS of the assignment, - // but to get when's to work in a convenient fashion, for them it's - // instead boolean T. - ValPtr AssignVal() { return val; } - const ValPtr& AssignVal() const { return val; } + // The following is a hack that's used in "when" expressions to support + // assignments to new locals, like "when ( (local l = foo()) && ...". + // These methods return the value to use when evaluating such + // assignments. That would normally be the RHS of the assignment, + // but to get when's to work in a convenient fashion, for them it's + // instead boolean T. + ValPtr AssignVal() { return val; } + const ValPtr& AssignVal() const { return val; } protected: - bool TypeCheck(const AttributesPtr& attrs = nullptr); - bool TypeCheckArithmetics(TypeTag bt1, TypeTag bt2); + bool TypeCheck(const AttributesPtr& attrs = nullptr); + bool TypeCheckArithmetics(TypeTag bt1, TypeTag bt2); - bool is_init; - bool is_temp = false; // Optimization related + bool is_init; + bool is_temp = false; // Optimization related - ValPtr val; // optional - }; + ValPtr val; // optional +}; -class IndexSliceAssignExpr final : public AssignExpr - { +class IndexSliceAssignExpr final : public AssignExpr { public: - IndexSliceAssignExpr(ExprPtr op1, ExprPtr op2, bool is_init); - ValPtr Eval(Frame* f) const override; + IndexSliceAssignExpr(ExprPtr op1, ExprPtr op2, bool is_init); + ValPtr Eval(Frame* f) const override; - // Optimization-related: - ExprPtr Duplicate() override; - }; + // Optimization-related: + ExprPtr Duplicate() override; +}; -class IndexExpr : public BinaryExpr - { +class IndexExpr : public BinaryExpr { public: - IndexExpr(ExprPtr op1, ListExprPtr op2, bool is_slice = false, bool is_inside_when = false); + IndexExpr(ExprPtr op1, ListExprPtr op2, bool is_slice = false, bool is_inside_when = false); - bool CanAdd() const override; - bool CanDel() const override; + bool CanAdd() const override; + bool CanDel() const override; - void Add(Frame* f) override; - void Delete(Frame* f) override; + void Add(Frame* f) override; + void Delete(Frame* f) override; - void Assign(Frame* f, ValPtr v) override; - ExprPtr MakeLvalue() override; + void Assign(Frame* f, ValPtr v) override; + ExprPtr MakeLvalue() override; - // Need to override Eval since it can take a vector arg but does - // not necessarily return a vector. - ValPtr Eval(Frame* f) const override; + // Need to override Eval since it can take a vector arg but does + // not necessarily return a vector. + ValPtr Eval(Frame* f) const override; - bool IsSlice() const { return is_slice; } - bool IsInsideWhen() const { return is_inside_when; } + bool IsSlice() const { return is_slice; } + bool IsInsideWhen() const { return is_inside_when; } - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool HasReducedOps(Reducer* c) const override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool HasReducedOps(Reducer* c) const override; + StmtPtr ReduceToSingletons(Reducer* c) override; protected: - ValPtr Fold(Val* v1, Val* v2) const override; + ValPtr Fold(Val* v1, Val* v2) const override; - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - bool is_slice; - bool is_inside_when; - }; + bool is_slice; + bool is_inside_when; +}; // The following execute the heart of IndexExpr functionality for // vector slices and strings. @@ -1084,753 +1044,716 @@ extern VectorValPtr vector_int_select(VectorTypePtr vt, const VectorVal* v1, con // // TODO: One Fine Day we should do the equivalent for accessing fields // in records, too. -class IndexExprWhen final : public IndexExpr - { +class IndexExprWhen final : public IndexExpr { public: - static inline std::vector results = {}; - static inline int evaluating = 0; + static inline std::vector results = {}; + static inline int evaluating = 0; - static void StartEval() { ++evaluating; } + static void StartEval() { ++evaluating; } - static void EndEval() { --evaluating; } + static void EndEval() { --evaluating; } - static std::vector TakeAllResults() - { - auto rval = std::move(results); - results = {}; - return rval; - } + static std::vector TakeAllResults() { + auto rval = std::move(results); + results = {}; + return rval; + } - IndexExprWhen(ExprPtr op1, ListExprPtr op2, bool is_slice = false) - : IndexExpr(std::move(op1), std::move(op2), is_slice, true) - { - } + IndexExprWhen(ExprPtr op1, ListExprPtr op2, bool is_slice = false) + : IndexExpr(std::move(op1), std::move(op2), is_slice, true) {} - ValPtr Eval(Frame* f) const override - { - auto v = IndexExpr::Eval(f); + ValPtr Eval(Frame* f) const override { + auto v = IndexExpr::Eval(f); - if ( v && evaluating > 0 ) - results.emplace_back(v); + if ( v && evaluating > 0 ) + results.emplace_back(v); - return v; - } + return v; + } - // Optimization-related: - ExprPtr Duplicate() override; - }; + // Optimization-related: + ExprPtr Duplicate() override; +}; -class FieldExpr final : public UnaryExpr - { +class FieldExpr final : public UnaryExpr { public: - FieldExpr(ExprPtr op, const char* field_name); - ~FieldExpr() override; + FieldExpr(ExprPtr op, const char* field_name); + ~FieldExpr() override; - int Field() const { return field; } - const char* FieldName() const { return field_name; } + int Field() const { return field; } + const char* FieldName() const { return field_name; } - bool CanDel() const override; + bool CanDel() const override; - void Assign(Frame* f, ValPtr v) override; - void Delete(Frame* f) override; + void Assign(Frame* f, ValPtr v) override; + void Delete(Frame* f) override; - ExprPtr MakeLvalue() override; + ExprPtr MakeLvalue() override; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr Fold(Val* v) const override; + ValPtr Fold(Val* v) const override; - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - const char* field_name; - const TypeDecl* td; - int field; // -1 = attributes - }; + const char* field_name; + const TypeDecl* td; + int field; // -1 = attributes +}; // "rec?$fieldname" is true if the value of $fieldname in rec is not nil. // "rec?$$attrname" is true if the attribute attrname is not nil. -class HasFieldExpr final : public UnaryExpr - { +class HasFieldExpr final : public UnaryExpr { public: - HasFieldExpr(ExprPtr op, const char* field_name); - ~HasFieldExpr() override; + HasFieldExpr(ExprPtr op, const char* field_name); + ~HasFieldExpr() override; - const char* FieldName() const { return field_name; } - int Field() const { return field; } + const char* FieldName() const { return field_name; } + int Field() const { return field; } - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool IsReduced(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool IsReduced(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - ValPtr Fold(Val* v) const override; + ValPtr Fold(Val* v) const override; - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - const char* field_name; - int field; - }; + const char* field_name; + int field; +}; -class RecordConstructorExpr final : public Expr - { +class RecordConstructorExpr final : public Expr { public: - explicit RecordConstructorExpr(ListExprPtr constructor_list); + explicit RecordConstructorExpr(ListExprPtr constructor_list); - // This form is used to construct records of a known (ultimate) type. - explicit RecordConstructorExpr(RecordTypePtr known_rt, ListExprPtr constructor_list); + // This form is used to construct records of a known (ultimate) type. + explicit RecordConstructorExpr(RecordTypePtr known_rt, ListExprPtr constructor_list); - ListExprPtr Op() const { return op; } - const auto& Map() const { return map; } + ListExprPtr Op() const { return op; } + const auto& Map() const { return map; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - bool IsPure() const override; + bool IsPure() const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Duplicate() override; + ExprPtr Inline(Inliner* inl) override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - ListExprPtr op; - std::optional> map; - }; + ListExprPtr op; + std::optional> map; +}; -class TableConstructorExpr final : public UnaryExpr - { +class TableConstructorExpr final : public UnaryExpr { public: - TableConstructorExpr(ListExprPtr constructor_list, std::unique_ptr> attrs, - TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); + TableConstructorExpr(ListExprPtr constructor_list, std::unique_ptr> attrs, + TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); - void SetAttrs(AttributesPtr _attrs) { attrs = std::move(_attrs); } - const AttributesPtr& GetAttrs() const { return attrs; } + void SetAttrs(AttributesPtr _attrs) { attrs = std::move(_attrs); } + const AttributesPtr& GetAttrs() const { return attrs; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - AttributesPtr attrs; - }; + AttributesPtr attrs; +}; -class SetConstructorExpr final : public UnaryExpr - { +class SetConstructorExpr final : public UnaryExpr { public: - SetConstructorExpr(ListExprPtr constructor_list, std::unique_ptr> attrs, - TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); + SetConstructorExpr(ListExprPtr constructor_list, std::unique_ptr> attrs, + TypePtr arg_type = nullptr, AttributesPtr arg_attrs = nullptr); - void SetAttrs(AttributesPtr _attrs) { attrs = std::move(_attrs); } - const AttributesPtr& GetAttrs() const { return attrs; } + void SetAttrs(AttributesPtr _attrs) { attrs = std::move(_attrs); } + const AttributesPtr& GetAttrs() const { return attrs; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - AttributesPtr attrs; - }; + AttributesPtr attrs; +}; -class VectorConstructorExpr final : public UnaryExpr - { +class VectorConstructorExpr final : public UnaryExpr { public: - explicit VectorConstructorExpr(ListExprPtr constructor_list, TypePtr arg_type = nullptr); + explicit VectorConstructorExpr(ListExprPtr constructor_list, TypePtr arg_type = nullptr); - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool HasReducedOps(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; protected: - void ExprDescribe(ODesc* d) const override; - }; + void ExprDescribe(ODesc* d) const override; +}; -class FieldAssignExpr final : public UnaryExpr - { +class FieldAssignExpr final : public UnaryExpr { public: - FieldAssignExpr(const char* field_name, ExprPtr value); + FieldAssignExpr(const char* field_name, ExprPtr value); - const char* FieldName() const { return field_name.c_str(); } + const char* FieldName() const { return field_name.c_str(); } - // When these are first constructed, we don't know the type. - // The following method coerces/promotes the assignment expression - // as needed, once we do know the type. - // - // Returns true on success, false if the types were incompatible - // (in which case an error is reported). - bool PromoteTo(TypePtr t); + // When these are first constructed, we don't know the type. + // The following method coerces/promotes the assignment expression + // as needed, once we do know the type. + // + // Returns true on success, false if the types were incompatible + // (in which case an error is reported). + bool PromoteTo(TypePtr t); - bool IsRecordElement(TypeDecl* td) const override; + bool IsRecordElement(TypeDecl* td) const override; - // Optimization-related: - ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override { return true; } - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + // Optimization-related: + ExprPtr Duplicate() override; + bool WillTransform(Reducer* c) const override { return true; } + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - std::string field_name; - }; + std::string field_name; +}; -class ArithCoerceExpr final : public UnaryExpr - { +class ArithCoerceExpr final : public UnaryExpr { public: - ArithCoerceExpr(ExprPtr op, TypeTag t); + ArithCoerceExpr(ExprPtr op, TypeTag t); - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool WillTransform(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool WillTransform(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; protected: - ValPtr FoldSingleVal(ValPtr v, const TypePtr& t) const; - ValPtr Fold(Val* v) const override; - }; + ValPtr FoldSingleVal(ValPtr v, const TypePtr& t) const; + ValPtr Fold(Val* v) const override; +}; -class RecordCoerceExpr final : public UnaryExpr - { +class RecordCoerceExpr final : public UnaryExpr { public: - RecordCoerceExpr(ExprPtr op, RecordTypePtr r); + RecordCoerceExpr(ExprPtr op, RecordTypePtr r); - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - const std::vector& Map() const { return map; } + const std::vector& Map() const { return map; } protected: - ValPtr Fold(Val* v) const override; + ValPtr Fold(Val* v) const override; - // For each super-record slot, gives subrecord slot with which to - // fill it. - std::vector map; - }; + // For each super-record slot, gives subrecord slot with which to + // fill it. + std::vector map; +}; extern RecordValPtr coerce_to_record(RecordTypePtr rt, Val* v, const std::vector& map); -class TableCoerceExpr final : public UnaryExpr - { +class TableCoerceExpr final : public UnaryExpr { public: - TableCoerceExpr(ExprPtr op, TableTypePtr r, bool type_check = true); - ~TableCoerceExpr() override = default; + TableCoerceExpr(ExprPtr op, TableTypePtr r, bool type_check = true); + ~TableCoerceExpr() override = default; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class VectorCoerceExpr final : public UnaryExpr - { +class VectorCoerceExpr final : public UnaryExpr { public: - VectorCoerceExpr(ExprPtr op, VectorTypePtr v); - ~VectorCoerceExpr() override = default; + VectorCoerceExpr(ExprPtr op, VectorTypePtr v); + ~VectorCoerceExpr() override = default; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr Fold(Val* v) const override; - }; + ValPtr Fold(Val* v) const override; +}; -class ScheduleTimer final : public Timer - { +class ScheduleTimer final : public Timer { public: - ScheduleTimer(const EventHandlerPtr& event, zeek::Args args, double t); - ~ScheduleTimer() override = default; + ScheduleTimer(const EventHandlerPtr& event, zeek::Args args, double t); + ~ScheduleTimer() override = default; - void Dispatch(double t, bool is_expire) override; + void Dispatch(double t, bool is_expire) override; protected: - EventHandlerPtr event; - zeek::Args args; - }; + EventHandlerPtr event; + zeek::Args args; +}; -class ScheduleExpr final : public Expr - { +class ScheduleExpr final : public Expr { public: - ScheduleExpr(ExprPtr when, EventExprPtr event); + ScheduleExpr(ExprPtr when, EventExprPtr event); - bool IsPure() const override { return false; } + bool IsPure() const override { return false; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - Expr* When() const { return when.get(); } - EventExpr* Event() const { return event.get(); } + Expr* When() const { return when.get(); } + EventExpr* Event() const { return event.get(); } - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Duplicate() override; + ExprPtr Inline(Inliner* inl) override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr GetOp1() const override final; - ExprPtr GetOp2() const override final; + ExprPtr GetOp1() const override final; + ExprPtr GetOp2() const override final; - void SetOp1(ExprPtr _op) override final; - void SetOp2(ExprPtr _op) override final; + void SetOp1(ExprPtr _op) override final; + void SetOp2(ExprPtr _op) override final; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - ExprPtr when; - EventExprPtr event; - }; + ExprPtr when; + EventExprPtr event; +}; -class InExpr final : public BinaryExpr - { +class InExpr final : public BinaryExpr { public: - InExpr(ExprPtr op1, ExprPtr op2); + InExpr(ExprPtr op1, ExprPtr op2); - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - bool HasReducedOps(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; protected: - ValPtr Fold(Val* v1, Val* v2) const override; - }; + ValPtr Fold(Val* v1, Val* v2) const override; +}; -class CallExpr final : public Expr - { +class CallExpr final : public Expr { public: - CallExpr(ExprPtr func, ListExprPtr args, bool in_hook = false, bool in_when = false); + CallExpr(ExprPtr func, ListExprPtr args, bool in_hook = false, bool in_when = false); - Expr* Func() const { return func.get(); } - ListExpr* Args() const { return args.get(); } - ListExprPtr ArgsPtr() const { return args; } + Expr* Func() const { return func.get(); } + ListExpr* Args() const { return args.get(); } + ListExprPtr ArgsPtr() const { return args; } - bool IsPure() const override; - bool IsInWhen() const { return in_when; } + bool IsPure() const override; + bool IsInWhen() const { return in_when; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Duplicate() override; + ExprPtr Inline(Inliner* inl) override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - ExprPtr func; - ListExprPtr args; - bool in_when; - }; + ExprPtr func; + ListExprPtr args; + bool in_when; +}; /** * Class that represents an anonymous function expression in Zeek. * On evaluation, captures the frame that it is evaluated in. This becomes * the closure for the instance of the function that it creates. */ -class LambdaExpr final : public Expr - { +class LambdaExpr final : public Expr { public: - LambdaExpr(FunctionIngredientsPtr ingredients, IDPList outer_ids, std::string name = "", - StmtPtr when_parent = nullptr); + LambdaExpr(FunctionIngredientsPtr ingredients, IDPList outer_ids, std::string name = "", + StmtPtr when_parent = nullptr); - const std::string& Name() const { return my_name; } + const std::string& Name() const { return my_name; } - const IDPList& OuterIDs() const { return outer_ids; } + const IDPList& OuterIDs() const { return outer_ids; } - // Lambda's potentially have their own private copy of captures, - // to enable updates to the set during script optimization. - using CaptureList = std::vector; - const std::optional& GetCaptures() const { return captures; } + // Lambda's potentially have their own private copy of captures, + // to enable updates to the set during script optimization. + using CaptureList = std::vector; + const std::optional& GetCaptures() const { return captures; } - ValPtr Eval(Frame* f) const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + ValPtr Eval(Frame* f) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - ScopePtr GetScope() const; + ScopePtr GetScope() const; - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; - const ScriptFuncPtr& PrimaryFunc() const { return primary_func; } + const ScriptFuncPtr& PrimaryFunc() const { return primary_func; } - const FunctionIngredientsPtr& Ingredients() const { return ingredients; } + const FunctionIngredientsPtr& Ingredients() const { return ingredients; } - void ReplaceBody(StmtPtr new_body); + void ReplaceBody(StmtPtr new_body); - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; protected: - // Constructor used for script optimization. - LambdaExpr(LambdaExpr* orig); + // Constructor used for script optimization. + LambdaExpr(LambdaExpr* orig); - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; private: - friend class WhenInfo; + friend class WhenInfo; - // "Private" captures are captures that correspond to "when" - // condition locals. These aren't true captures in that they - // don't come from the outer frame when the lambda is constructed, - // but they otherwise behave like captures in that they persist - // across function invocations. - void SetPrivateCaptures(const IDSet& pcaps) { private_captures = pcaps; } + // "Private" captures are captures that correspond to "when" + // condition locals. These aren't true captures in that they + // don't come from the outer frame when the lambda is constructed, + // but they otherwise behave like captures in that they persist + // across function invocations. + void SetPrivateCaptures(const IDSet& pcaps) { private_captures = pcaps; } - bool CheckCaptures(StmtPtr when_parent); - void BuildName(); + bool CheckCaptures(StmtPtr when_parent); + void BuildName(); - void UpdateCaptures(Reducer* c); + void UpdateCaptures(Reducer* c); - FunctionIngredientsPtr ingredients; - ScriptFuncPtr primary_func; - IDPtr lambda_id; - IDPList outer_ids; - std::optional captures; - IDSet private_captures; + FunctionIngredientsPtr ingredients; + ScriptFuncPtr primary_func; + IDPtr lambda_id; + IDPList outer_ids; + std::optional captures; + IDSet private_captures; - std::string my_name; - }; + std::string my_name; +}; // This comes before EventExpr so that EventExpr::GetOp1 can return its // arguments as convertible to ExprPtr. -class ListExpr : public Expr - { +class ListExpr : public Expr { public: - ListExpr(); - explicit ListExpr(ExprPtr e); - ~ListExpr() override; + ListExpr(); + explicit ListExpr(ExprPtr e); + ~ListExpr() override; - void Append(ExprPtr e); + void Append(ExprPtr e); - const ExprPList& Exprs() const { return exprs; } - ExprPList& Exprs() { return exprs; } + const ExprPList& Exprs() const { return exprs; } + ExprPList& Exprs() { return exprs; } - // True if the entire list represents pure values. - bool IsPure() const override; + // True if the entire list represents pure values. + bool IsPure() const override; - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - TypePtr InitType() const override; - ExprPtr MakeLvalue() override; - void Assign(Frame* f, ValPtr v) override; + TypePtr InitType() const override; + ExprPtr MakeLvalue() override; + void Assign(Frame* f, ValPtr v) override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Duplicate() override; + ExprPtr Inline(Inliner* inl) override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - ExprPList exprs; - }; + ExprPList exprs; +}; -class EventExpr final : public Expr - { +class EventExpr final : public Expr { public: - EventExpr(const char* name, ListExprPtr args); + EventExpr(const char* name, ListExprPtr args); - const char* Name() const { return name.c_str(); } - ListExpr* Args() const { return args.get(); } - EventHandlerPtr Handler() const { return handler; } + const char* Name() const { return name.c_str(); } + ListExpr* Args() const { return args.get(); } + EventHandlerPtr Handler() const { return handler; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; - // Optimization-related: - ExprPtr Duplicate() override; - ExprPtr Inline(Inliner* inl) override; + // Optimization-related: + ExprPtr Duplicate() override; + ExprPtr Inline(Inliner* inl) override; - bool IsReduced(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - StmtPtr ReduceToSingletons(Reducer* c) override; + bool IsReduced(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + StmtPtr ReduceToSingletons(Reducer* c) override; - ExprPtr GetOp1() const override final { return args; } - void SetOp1(ExprPtr _op) override final { args = {NewRef{}, _op->AsListExpr()}; } + ExprPtr GetOp1() const override final { return args; } + void SetOp1(ExprPtr _op) override final { args = {NewRef{}, _op->AsListExpr()}; } protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - std::string name; - EventHandlerPtr handler; - ListExprPtr args; - }; + std::string name; + EventHandlerPtr handler; + ListExprPtr args; +}; -class RecordAssignExpr final : public ListExpr - { +class RecordAssignExpr final : public ListExpr { public: - RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init); - }; + RecordAssignExpr(const ExprPtr& record, const ExprPtr& init_list, bool is_init); +}; -class CastExpr final : public UnaryExpr - { +class CastExpr final : public UnaryExpr { public: - CastExpr(ExprPtr op, TypePtr t); + CastExpr(ExprPtr op, TypePtr t); - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr Fold(Val* v) const override; - void ExprDescribe(ODesc* d) const override; - }; + ValPtr Fold(Val* v) const override; + void ExprDescribe(ODesc* d) const override; +}; // Returns the value 'v' cast to type 't'. On an error, returns nil // and populates "error" with an error message. extern ValPtr cast_value(ValPtr v, const TypePtr& t, std::string& error); -class IsExpr final : public UnaryExpr - { +class IsExpr final : public UnaryExpr { public: - IsExpr(ExprPtr op, TypePtr t); + IsExpr(ExprPtr op, TypePtr t); - const TypePtr& TestType() const { return t; } + const TypePtr& TestType() const { return t; } - // Optimization-related: - ExprPtr Duplicate() override; + // Optimization-related: + ExprPtr Duplicate() override; protected: - ValPtr Fold(Val* v) const override; - void ExprDescribe(ODesc* d) const override; + ValPtr Fold(Val* v) const override; + void ExprDescribe(ODesc* d) const override; private: - TypePtr t; - }; + TypePtr t; +}; -class InlineExpr : public Expr - { +class InlineExpr : public Expr { public: - InlineExpr(ListExprPtr arg_args, std::vector params, StmtPtr body, int frame_offset, - TypePtr ret_type); + InlineExpr(ListExprPtr arg_args, std::vector params, StmtPtr body, int frame_offset, TypePtr ret_type); - bool IsPure() const override; + bool IsPure() const override; - ListExprPtr Args() const { return args; } - StmtPtr Body() const { return body; } + ListExprPtr Args() const { return args; } + StmtPtr Body() const { return body; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - ExprPtr Duplicate() override; + ExprPtr Duplicate() override; - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override { return false; } - bool WillTransform(Reducer* c) const override { return true; } - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override { return false; } + bool WillTransform(Reducer* c) const override { return true; } + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - std::vector params; - int frame_offset; - ListExprPtr args; - StmtPtr body; - }; + std::vector params; + int frame_offset; + ListExprPtr args; + StmtPtr body; +}; // A companion to AddToExpr that's for vector-append, instantiated during // the reduction process. -class AppendToExpr : public BinaryExpr - { +class AppendToExpr : public BinaryExpr { public: - AppendToExpr(ExprPtr op1, ExprPtr op2); - ValPtr Eval(Frame* f) const override; + AppendToExpr(ExprPtr op1, ExprPtr op2); + ValPtr Eval(Frame* f) const override; - ExprPtr Duplicate() override; + ExprPtr Duplicate() override; - bool IsPure() const override { return false; } - bool IsReduced(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; - }; + bool IsPure() const override { return false; } + bool IsReduced(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; +}; // An internal class for reduced form. -class IndexAssignExpr : public BinaryExpr - { +class IndexAssignExpr : public BinaryExpr { public: - // "op1[op2] = op3", all reduced. - IndexAssignExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); + // "op1[op2] = op3", all reduced. + IndexAssignExpr(ExprPtr op1, ExprPtr op2, ExprPtr op3); - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - ExprPtr Duplicate() override; + ExprPtr Duplicate() override; - bool IsPure() const override { return false; } - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; + bool IsPure() const override { return false; } + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr GetOp3() const override final { return op3; } - void SetOp3(ExprPtr _op) override final { op3 = std::move(_op); } + ExprPtr GetOp3() const override final { return op3; } + void SetOp3(ExprPtr _op) override final { op3 = std::move(_op); } - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - ExprPtr op3; // assignment RHS - }; + ExprPtr op3; // assignment RHS +}; // An internal class for reduced form. -class FieldLHSAssignExpr : public BinaryExpr - { +class FieldLHSAssignExpr : public BinaryExpr { public: - // "op1$field = RHS", where RHS is reduced with respect to - // ReduceToFieldAssignment(). - FieldLHSAssignExpr(ExprPtr op1, ExprPtr op2, const char* field_name, int field); + // "op1$field = RHS", where RHS is reduced with respect to + // ReduceToFieldAssignment(). + FieldLHSAssignExpr(ExprPtr op1, ExprPtr op2, const char* field_name, int field); - const char* FieldName() const { return field_name; } - int Field() const { return field; } + const char* FieldName() const { return field_name; } + int Field() const { return field; } - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - ExprPtr Duplicate() override; + ExprPtr Duplicate() override; - bool IsPure() const override { return false; } - bool IsReduced(Reducer* c) const override; - bool HasReducedOps(Reducer* c) const override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; + bool IsPure() const override { return false; } + bool IsReduced(Reducer* c) const override; + bool HasReducedOps(Reducer* c) const override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr ReduceToSingleton(Reducer* c, StmtPtr& red_stmt) override; protected: - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - const char* field_name; - int field; - }; + const char* field_name; + int field; +}; // Expression to explicitly capture conversion to an "any" type, rather // than it occurring implicitly during script interpretation. -class CoerceToAnyExpr : public UnaryExpr - { +class CoerceToAnyExpr : public UnaryExpr { public: - CoerceToAnyExpr(ExprPtr op); + CoerceToAnyExpr(ExprPtr op); protected: - ValPtr Fold(Val* v) const override; + ValPtr Fold(Val* v) const override; - ExprPtr Duplicate() override; - }; + ExprPtr Duplicate() override; +}; // Same, but for conversion from an "any" type. -class CoerceFromAnyExpr : public UnaryExpr - { +class CoerceFromAnyExpr : public UnaryExpr { public: - CoerceFromAnyExpr(ExprPtr op, TypePtr to_type); + CoerceFromAnyExpr(ExprPtr op, TypePtr to_type); protected: - ValPtr Fold(Val* v) const override; + ValPtr Fold(Val* v) const override; - ExprPtr Duplicate() override; - }; + ExprPtr Duplicate() override; +}; // ... and for conversion from a "vector of any" type. -class CoerceFromAnyVecExpr : public UnaryExpr - { +class CoerceFromAnyVecExpr : public UnaryExpr { public: - // to_type is yield type, not VectorType. - CoerceFromAnyVecExpr(ExprPtr op, TypePtr to_type); + // to_type is yield type, not VectorType. + CoerceFromAnyVecExpr(ExprPtr op, TypePtr to_type); - // Can't use UnaryExpr's Eval() because it will do folding - // over the individual vector elements. - ValPtr Eval(Frame* f) const override; + // Can't use UnaryExpr's Eval() because it will do folding + // over the individual vector elements. + ValPtr Eval(Frame* f) const override; protected: - ExprPtr Duplicate() override; - }; + ExprPtr Duplicate() override; +}; // Expression used to explicitly capture [a, b, c, ...] = x assignments. -class AnyIndexExpr : public UnaryExpr - { +class AnyIndexExpr : public UnaryExpr { public: - AnyIndexExpr(ExprPtr op, int index); + AnyIndexExpr(ExprPtr op, int index); - int Index() const { return index; } + int Index() const { return index; } protected: - ValPtr Fold(Val* v) const override; + ValPtr Fold(Val* v) const override; - void ExprDescribe(ODesc* d) const override; + void ExprDescribe(ODesc* d) const override; - ExprPtr Duplicate() override; - ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; + ExprPtr Duplicate() override; + ExprPtr Reduce(Reducer* c, StmtPtr& red_stmt) override; - int index; - }; + int index; +}; // Used internally for optimization, when a placeholder is needed. -class NopExpr : public Expr - { +class NopExpr : public Expr { public: - explicit NopExpr() : Expr(EXPR_NOP) { } + explicit NopExpr() : Expr(EXPR_NOP) {} - ValPtr Eval(Frame* f) const override; + ValPtr Eval(Frame* f) const override; - ExprPtr Duplicate() override; + ExprPtr Duplicate() override; - TraversalCode Traverse(TraversalCallback* cb) const override; + TraversalCode Traverse(TraversalCallback* cb) const override; protected: - void ExprDescribe(ODesc* d) const override; - }; + void ExprDescribe(ODesc* d) const override; +}; // Assigns v1[v2] = v3. Returns an error message, or nullptr on success. // Factored out so that compiled code can call it as well as the interpreter. extern const char* assign_to_index(ValPtr v1, ValPtr v2, ValPtr v3, bool& iterators_invalidated); -inline Val* Expr::ExprVal() const - { - if ( ! IsConst() ) - BadTag("ExprVal::Val", expr_name(tag), expr_name(EXPR_CONST)); - return ((ConstExpr*)this)->Value(); - } +inline Val* Expr::ExprVal() const { + if ( ! IsConst() ) + BadTag("ExprVal::Val", expr_name(tag), expr_name(EXPR_CONST)); + return ((ConstExpr*)this)->Value(); +} // Decides whether to return an AssignExpr or a RecordAssignExpr. extern ExprPtr get_assign_expr(ExprPtr op1, ExprPtr op2, bool is_init); @@ -1869,25 +1792,13 @@ extern std::optional> eval_list(Frame* f, const ListExpr* l) extern bool expr_greater(const Expr* e1, const Expr* e2); // True if the given Expr* has a vector type -inline bool is_vector(Expr* e) - { - return e->GetType()->Tag() == TYPE_VECTOR; - } -inline bool is_vector(const ExprPtr& e) - { - return is_vector(e.get()); - } +inline bool is_vector(Expr* e) { return e->GetType()->Tag() == TYPE_VECTOR; } +inline bool is_vector(const ExprPtr& e) { return is_vector(e.get()); } // True if the given Expr* has a list type -inline bool is_list(Expr* e) - { - return e->GetType()->Tag() == TYPE_LIST; - } +inline bool is_list(Expr* e) { return e->GetType()->Tag() == TYPE_LIST; } -inline bool is_list(const ExprPtr& e) - { - return is_list(e.get()); - } +inline bool is_list(const ExprPtr& e) { return is_list(e.get()); } - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/File.cc b/src/File.cc index 958c1c2300..cc77df035d 100644 --- a/src/File.cc +++ b/src/File.cc @@ -31,331 +31,296 @@ #include "zeek/Type.h" #include "zeek/Var.h" -namespace zeek - { +namespace zeek { std::list> File::open_files; // Maximizes the number of open file descriptors. -static void maximize_num_fds() - { - struct rlimit rl; - if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 ) - reporter->FatalError("maximize_num_fds(): getrlimit failed"); +static void maximize_num_fds() { + struct rlimit rl; + if ( getrlimit(RLIMIT_NOFILE, &rl) < 0 ) + reporter->FatalError("maximize_num_fds(): getrlimit failed"); - if ( rl.rlim_max == RLIM_INFINITY ) - { - // Don't try raising the current limit. - return; - } + if ( rl.rlim_max == RLIM_INFINITY ) { + // Don't try raising the current limit. + return; + } - // See if we can raise the current to the maximum. - rl.rlim_cur = rl.rlim_max; + // See if we can raise the current to the maximum. + rl.rlim_cur = rl.rlim_max; - if ( setrlimit(RLIMIT_NOFILE, &rl) < 0 ) - reporter->FatalError("maximize_num_fds(): setrlimit failed"); - } + if ( setrlimit(RLIMIT_NOFILE, &rl) < 0 ) + reporter->FatalError("maximize_num_fds(): setrlimit failed"); +} -File::File(FILE* arg_f) - { - Init(); - f = arg_f; - name = access = nullptr; - t = base_type(TYPE_STRING); - is_open = (f != nullptr); - } +File::File(FILE* arg_f) { + Init(); + f = arg_f; + name = access = nullptr; + t = base_type(TYPE_STRING); + is_open = (f != nullptr); +} -File::File(FILE* arg_f, const char* arg_name, const char* arg_access) - { - Init(); - f = arg_f; - name = util::copy_string(arg_name); - access = util::copy_string(arg_access); - t = base_type(TYPE_STRING); - is_open = (f != nullptr); - } +File::File(FILE* arg_f, const char* arg_name, const char* arg_access) { + Init(); + f = arg_f; + name = util::copy_string(arg_name); + access = util::copy_string(arg_access); + t = base_type(TYPE_STRING); + is_open = (f != nullptr); +} -File::File(const char* arg_name, const char* arg_access) - { - Init(); - f = nullptr; - name = util::copy_string(arg_name); - access = util::copy_string(arg_access); - t = base_type(TYPE_STRING); +File::File(const char* arg_name, const char* arg_access) { + Init(); + f = nullptr; + name = util::copy_string(arg_name); + access = util::copy_string(arg_access); + t = base_type(TYPE_STRING); - if ( util::streq(name, "/dev/stdin") ) - f = stdin; - else if ( util::streq(name, "/dev/stdout") ) - f = stdout; - else if ( util::streq(name, "/dev/stderr") ) - f = stderr; + if ( util::streq(name, "/dev/stdin") ) + f = stdin; + else if ( util::streq(name, "/dev/stdout") ) + f = stdout; + else if ( util::streq(name, "/dev/stderr") ) + f = stderr; - if ( f ) - is_open = true; + if ( f ) + is_open = true; - else if ( ! Open() ) - { - reporter->Error("cannot open %s: %s", name, strerror(errno)); - is_open = false; - } - } + else if ( ! Open() ) { + reporter->Error("cannot open %s: %s", name, strerror(errno)); + is_open = false; + } +} -const char* File::Name() const - { - if ( name ) - return name; +const char* File::Name() const { + if ( name ) + return name; - if ( f == stdin ) - return "/dev/stdin"; + if ( f == stdin ) + return "/dev/stdin"; - if ( f == stdout ) - return "/dev/stdout"; + if ( f == stdout ) + return "/dev/stdout"; - if ( f == stderr ) - return "/dev/stderr"; + if ( f == stderr ) + return "/dev/stderr"; - return nullptr; - } + return nullptr; +} -bool File::Open(FILE* file, const char* mode) - { - static bool fds_maximized = false; - open_time = run_state::network_time ? run_state::network_time : util::current_time(); +bool File::Open(FILE* file, const char* mode) { + static bool fds_maximized = false; + open_time = run_state::network_time ? run_state::network_time : util::current_time(); - if ( ! fds_maximized ) - { - // Haven't initialized yet. - maximize_num_fds(); - fds_maximized = true; - } + if ( ! fds_maximized ) { + // Haven't initialized yet. + maximize_num_fds(); + fds_maximized = true; + } - f = file; + f = file; - if ( ! f ) - { - if ( ! mode ) - f = fopen(name, access); - else - f = fopen(name, mode); - } + if ( ! f ) { + if ( ! mode ) + f = fopen(name, access); + else + f = fopen(name, mode); + } - SetBuf(buffered); + SetBuf(buffered); - if ( ! f ) - { - is_open = false; - return false; - } + if ( ! f ) { + is_open = false; + return false; + } - is_open = true; - open_files.emplace_back(name, this); + is_open = true; + open_files.emplace_back(name, this); - RaiseOpenEvent(); + RaiseOpenEvent(); - return true; - } + return true; +} -File::~File() - { - Close(); - Unref(attrs); +File::~File() { + Close(); + Unref(attrs); - delete[] name; - delete[] access; + delete[] name; + delete[] access; #ifdef USE_PERFTOOLS_DEBUG - heap_checker->UnIgnoreObject(this); + heap_checker->UnIgnoreObject(this); #endif - } +} -void File::Init() - { - open_time = 0; - is_open = false; - attrs = nullptr; - buffered = true; - raw_output = false; +void File::Init() { + open_time = 0; + is_open = false; + attrs = nullptr; + buffered = true; + raw_output = false; #ifdef USE_PERFTOOLS_DEBUG - heap_checker->IgnoreObject(this); + heap_checker->IgnoreObject(this); #endif - } +} -FILE* File::FileHandle() - { - return f; - } +FILE* File::FileHandle() { return f; } -FILE* File::Seek(long new_position) - { - if ( ! FileHandle() ) - return nullptr; +FILE* File::Seek(long new_position) { + if ( ! FileHandle() ) + return nullptr; - if ( fseek(f, new_position, SEEK_SET) < 0 ) - reporter->Error("seek failed"); + if ( fseek(f, new_position, SEEK_SET) < 0 ) + reporter->Error("seek failed"); - return f; - } + return f; +} -void File::SetBuf(bool arg_buffered) - { - if ( ! f ) - return; +void File::SetBuf(bool arg_buffered) { + if ( ! f ) + return; - if ( util::detail::setvbuf(f, NULL, arg_buffered ? _IOFBF : _IOLBF, 0) != 0 ) - reporter->Error("setvbuf failed"); + if ( util::detail::setvbuf(f, NULL, arg_buffered ? _IOFBF : _IOLBF, 0) != 0 ) + reporter->Error("setvbuf failed"); - buffered = arg_buffered; - } + buffered = arg_buffered; +} -bool File::Close() - { - if ( ! is_open ) - return true; +bool File::Close() { + if ( ! is_open ) + return true; - // Do not close stdin/stdout/stderr. - if ( f == stdin || f == stdout || f == stderr ) - return false; + // Do not close stdin/stdout/stderr. + if ( f == stdin || f == stdout || f == stderr ) + return false; - if ( ! f ) - return false; + if ( ! f ) + return false; - fclose(f); - f = nullptr; - open_time = 0; - is_open = false; + fclose(f); + f = nullptr; + open_time = 0; + is_open = false; - Unlink(); + Unlink(); - return true; - } + return true; +} -void File::Unlink() - { - for ( auto it = open_files.begin(); it != open_files.end(); ++it ) - { - if ( (*it).second == this ) - { - open_files.erase(it); - return; - } - } - } +void File::Unlink() { + for ( auto it = open_files.begin(); it != open_files.end(); ++it ) { + if ( (*it).second == this ) { + open_files.erase(it); + return; + } + } +} -void File::Describe(ODesc* d) const - { - d->AddSP("file"); +void File::Describe(ODesc* d) const { + d->AddSP("file"); - if ( name ) - { - d->Add("\""); - d->Add(name); - d->AddSP("\""); - } + if ( name ) { + d->Add("\""); + d->Add(name); + d->AddSP("\""); + } - d->AddSP("of"); - if ( t ) - t->Describe(d); - else - d->Add("(no type)"); - } + d->AddSP("of"); + if ( t ) + t->Describe(d); + else + d->Add("(no type)"); +} -void File::SetAttrs(detail::Attributes* arg_attrs) - { - if ( ! arg_attrs ) - return; +void File::SetAttrs(detail::Attributes* arg_attrs) { + if ( ! arg_attrs ) + return; - attrs = arg_attrs; - Ref(attrs); + attrs = arg_attrs; + Ref(attrs); - if ( attrs->Find(detail::ATTR_RAW_OUTPUT) ) - EnableRawOutput(); - } + if ( attrs->Find(detail::ATTR_RAW_OUTPUT) ) + EnableRawOutput(); +} -RecordVal* File::Rotate() - { - if ( ! is_open ) - return nullptr; +RecordVal* File::Rotate() { + if ( ! is_open ) + return nullptr; - // Do not rotate stdin/stdout/stderr. - if ( f == stdin || f == stdout || f == stderr ) - return nullptr; + // Do not rotate stdin/stdout/stderr. + if ( f == stdin || f == stdout || f == stderr ) + return nullptr; - static auto rotate_info = id::find_type("rotate_info"); - auto* info = new RecordVal(rotate_info); - FILE* newf = util::detail::rotate_file(name, info); + static auto rotate_info = id::find_type("rotate_info"); + auto* info = new RecordVal(rotate_info); + FILE* newf = util::detail::rotate_file(name, info); - if ( ! newf ) - { - Unref(info); - return nullptr; - } + if ( ! newf ) { + Unref(info); + return nullptr; + } - info->AssignTime(2, open_time); + info->AssignTime(2, open_time); - Unlink(); + Unlink(); - fclose(f); - f = nullptr; + fclose(f); + f = nullptr; - Open(newf); - return info; - } + Open(newf); + return info; +} -void File::CloseOpenFiles() - { - auto it = open_files.begin(); - while ( it != open_files.end() ) - { - auto el = it++; - (*el).second->Close(); - } - } +void File::CloseOpenFiles() { + auto it = open_files.begin(); + while ( it != open_files.end() ) { + auto el = it++; + (*el).second->Close(); + } +} -bool File::Write(const char* data, int len) - { - if ( ! is_open ) - return false; +bool File::Write(const char* data, int len) { + if ( ! is_open ) + return false; - if ( ! len ) - len = strlen(data); + if ( ! len ) + len = strlen(data); - if ( fwrite(data, len, 1, f) < 1 ) - return false; + if ( fwrite(data, len, 1, f) < 1 ) + return false; - return true; - } + return true; +} -void File::RaiseOpenEvent() - { - if ( ! ::file_opened ) - return; +void File::RaiseOpenEvent() { + if ( ! ::file_opened ) + return; - FilePtr bf{NewRef{}, this}; - auto* event = new Event(::file_opened, {make_intrusive(std::move(bf))}); - event_mgr.Dispatch(event, true); - } + FilePtr bf{NewRef{}, this}; + auto* event = new Event(::file_opened, {make_intrusive(std::move(bf))}); + event_mgr.Dispatch(event, true); +} -double File::Size() - { - fflush(f); - struct stat s; - if ( fstat(fileno(f), &s) < 0 ) - { - reporter->Error("can't stat fd for %s: %s", name, strerror(errno)); - return 0; - } +double File::Size() { + fflush(f); + struct stat s; + if ( fstat(fileno(f), &s) < 0 ) { + reporter->Error("can't stat fd for %s: %s", name, strerror(errno)); + return 0; + } - return s.st_size; - } + return s.st_size; +} -FilePtr File::Get(const char* name) - { - for ( const auto& el : open_files ) - if ( el.first == name ) - return {NewRef{}, el.second}; +FilePtr File::Get(const char* name) { + for ( const auto& el : open_files ) + if ( el.first == name ) + return {NewRef{}, el.second}; - return make_intrusive(name, "w"); - } + return make_intrusive(name, "w"); +} - } // namespace zeek +} // namespace zeek diff --git a/src/File.h b/src/File.h index 26ef3c154f..85037c6739 100644 --- a/src/File.h +++ b/src/File.h @@ -10,18 +10,16 @@ #include "zeek/Val.h" #include "zeek/util.h" -namespace zeek - { +namespace zeek { -namespace detail - { +namespace detail { class PrintStmt; class Attributes; extern void do_print_stmt(const std::vector& vals); - } // namespace detail; +} // namespace detail class RecordVal; class Type; @@ -30,94 +28,93 @@ using TypePtr = IntrusivePtr; class File; using FilePtr = IntrusivePtr; -class File final : public Obj - { +class File final : public Obj { public: - explicit File(FILE* arg_f); - File(FILE* arg_f, const char* filename, const char* access); - File(const char* filename, const char* access); - ~File() override; + explicit File(FILE* arg_f); + File(FILE* arg_f, const char* filename, const char* access); + File(const char* filename, const char* access); + ~File() override; - const char* Name() const; + const char* Name() const; - // Returns false if an error occurred. - bool Write(const char* data, int len = 0); + // Returns false if an error occurred. + bool Write(const char* data, int len = 0); - void Flush() { fflush(f); } + void Flush() { fflush(f); } - FILE* Seek(long position); // seek to absolute position + FILE* Seek(long position); // seek to absolute position - void SetBuf(bool buffered); // false=line buffered, true=fully buffered + void SetBuf(bool buffered); // false=line buffered, true=fully buffered - const TypePtr& GetType() const { return t; } + const TypePtr& GetType() const { return t; } - // Whether the file is open in a general sense; it might - // not be open as a Unix file due to our management of - // a finite number of FDs. - bool IsOpen() const { return is_open; } + // Whether the file is open in a general sense; it might + // not be open as a Unix file due to our management of + // a finite number of FDs. + bool IsOpen() const { return is_open; } - // Returns true if the close made sense, false if it was already - // closed, not active, or whatever. - bool Close(); + // Returns true if the close made sense, false if it was already + // closed, not active, or whatever. + bool Close(); - void Describe(ODesc* d) const override; + void Describe(ODesc* d) const override; - // Rotates the logfile. Returns rotate_info. - RecordVal* Rotate(); + // Rotates the logfile. Returns rotate_info. + RecordVal* Rotate(); - // Set &raw_output attribute. - void SetAttrs(detail::Attributes* attrs); + // Set &raw_output attribute. + void SetAttrs(detail::Attributes* attrs); - // Returns the current size of the file, after fresh stat'ing. - double Size(); + // Returns the current size of the file, after fresh stat'ing. + double Size(); - // Close all files which are currently open. - static void CloseOpenFiles(); + // Close all files which are currently open. + static void CloseOpenFiles(); - // Get the file with the given name, opening it if it doesn't yet exist. - static FilePtr Get(const char* name); + // Get the file with the given name, opening it if it doesn't yet exist. + static FilePtr Get(const char* name); - void EnableRawOutput() { raw_output = true; } - bool IsRawOutput() const { return raw_output; } + void EnableRawOutput() { raw_output = true; } + bool IsRawOutput() const { return raw_output; } protected: - friend void detail::do_print_stmt(const std::vector& vals); + friend void detail::do_print_stmt(const std::vector& vals); - File() { Init(); } - void Init(); + File() { Init(); } + void Init(); - /** - * If file is given, it's an open file to use already. - * If file is not given and mode is, the filename will be opened with that - * access mode. - */ - bool Open(FILE* f = nullptr, const char* mode = nullptr); + /** + * If file is given, it's an open file to use already. + * If file is not given and mode is, the filename will be opened with that + * access mode. + */ + bool Open(FILE* f = nullptr, const char* mode = nullptr); - void Unlink(); + void Unlink(); - // Returns nil if the file is not active, was in error, etc. - // (Protected because we do not want anyone to write directly - // to the file, but the PrintStmt friend uses this to check whether - // it's really stdout.) - FILE* FileHandle(); + // Returns nil if the file is not active, was in error, etc. + // (Protected because we do not want anyone to write directly + // to the file, but the PrintStmt friend uses this to check whether + // it's really stdout.) + FILE* FileHandle(); - // Raises a file_opened event. - void RaiseOpenEvent(); + // Raises a file_opened event. + void RaiseOpenEvent(); - FILE* f = nullptr; - TypePtr t; - char* name = nullptr; - char* access = nullptr; - detail::Attributes* attrs = nullptr; - double open_time = 0.0; - bool is_open = false; // whether the file is open in a general sense - bool buffered = false; - bool raw_output = false; + FILE* f = nullptr; + TypePtr t; + char* name = nullptr; + char* access = nullptr; + detail::Attributes* attrs = nullptr; + double open_time = 0.0; + bool is_open = false; // whether the file is open in a general sense + bool buffered = false; + bool raw_output = false; - static constexpr int MIN_BUFFER_SIZE = 1024; + static constexpr int MIN_BUFFER_SIZE = 1024; private: - static std::list> open_files; - }; + static std::list> open_files; +}; - } // namespace zeek +} // namespace zeek diff --git a/src/Flare.cc b/src/Flare.cc index f2dbc60155..8e2e7eb6b1 100644 --- a/src/Flare.cc +++ b/src/Flare.cc @@ -12,148 +12,134 @@ #include -#define fatalError(...) \ - do \ - { \ - if ( reporter ) \ - reporter->FatalError(__VA_ARGS__); \ - else \ - { \ - fprintf(stderr, __VA_ARGS__); \ - fprintf(stderr, "\n"); \ - _exit(1); \ - } \ - } while ( 0 ) +#define fatalError(...) \ + do { \ + if ( reporter ) \ + reporter->FatalError(__VA_ARGS__); \ + else { \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, "\n"); \ + _exit(1); \ + } \ + } while ( 0 ) #endif -namespace zeek::detail - { +namespace zeek::detail { Flare::Flare() #ifndef _MSC_VER - : pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) - { - } + : pipe(FD_CLOEXEC, FD_CLOEXEC, O_NONBLOCK, O_NONBLOCK) { +} #else - { - WSADATA wsaData; - if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 ) - fatalError("WSAStartup failure: %d", WSAGetLastError()); +{ + WSADATA wsaData; + if ( WSAStartup(MAKEWORD(2, 2), &wsaData) != 0 ) + fatalError("WSAStartup failure: %d", WSAGetLastError()); - recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, - WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT); - if ( recvfd == (int)INVALID_SOCKET ) - fatalError("WSASocket failure: %d", WSAGetLastError()); - sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, - WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT); - if ( sendfd == (int)INVALID_SOCKET ) - fatalError("WSASocket failure: %d", WSAGetLastError()); + recvfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT); + if ( recvfd == (int)INVALID_SOCKET ) + fatalError("WSASocket failure: %d", WSAGetLastError()); + sendfd = WSASocket(AF_INET, SOCK_DGRAM, IPPROTO_UDP, nullptr, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT); + if ( sendfd == (int)INVALID_SOCKET ) + fatalError("WSASocket failure: %d", WSAGetLastError()); - sockaddr_in sa; - memset(&sa, 0, sizeof(sa)); - sa.sin_family = AF_INET; - sa.sin_addr.s_addr = inet_addr("127.0.0.1"); - if ( bind(recvfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR ) - fatalError("bind failure: %d", WSAGetLastError()); - int salen = sizeof(sa); - if ( getsockname(recvfd, (sockaddr*)&sa, &salen) == SOCKET_ERROR ) - fatalError("getsockname failure: %d", WSAGetLastError()); - if ( connect(sendfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR ) - fatalError("connect failure: %d", WSAGetLastError()); - } + sockaddr_in sa; + memset(&sa, 0, sizeof(sa)); + sa.sin_family = AF_INET; + sa.sin_addr.s_addr = inet_addr("127.0.0.1"); + if ( bind(recvfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR ) + fatalError("bind failure: %d", WSAGetLastError()); + int salen = sizeof(sa); + if ( getsockname(recvfd, (sockaddr*)&sa, &salen) == SOCKET_ERROR ) + fatalError("getsockname failure: %d", WSAGetLastError()); + if ( connect(sendfd, (sockaddr*)&sa, sizeof(sa)) == SOCKET_ERROR ) + fatalError("connect failure: %d", WSAGetLastError()); +} #endif -[[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) - { - if ( signal_safe ) - abort(); +[[noreturn]] static void bad_pipe_op(const char* which, bool signal_safe) { + if ( signal_safe ) + abort(); - char buf[256]; - util::zeek_strerror_r(errno, buf, sizeof(buf)); + char buf[256]; + util::zeek_strerror_r(errno, buf, sizeof(buf)); - if ( reporter ) - reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf); - else - { - fprintf(stderr, "unexpected pipe %s failure: %s", which, buf); - abort(); - } - } + if ( reporter ) + reporter->FatalErrorWithCore("unexpected pipe %s failure: %s", which, buf); + else { + fprintf(stderr, "unexpected pipe %s failure: %s", which, buf); + abort(); + } +} -void Flare::Fire(bool signal_safe) - { - char tmp = 0; +void Flare::Fire(bool signal_safe) { + char tmp = 0; - for ( ;; ) - { + for ( ;; ) { #ifndef _MSC_VER - int n = write(pipe.WriteFD(), &tmp, 1); + int n = write(pipe.WriteFD(), &tmp, 1); #else - int n = send(sendfd, &tmp, 1, 0); + int n = send(sendfd, &tmp, 1, 0); #endif - if ( n > 0 ) - // Success -- wrote a byte to pipe. - break; + if ( n > 0 ) + // Success -- wrote a byte to pipe. + break; - if ( n < 0 ) - { + if ( n < 0 ) { #ifdef _MSC_VER - errno = WSAGetLastError(); - bad_pipe_op("send", signal_safe); + errno = WSAGetLastError(); + bad_pipe_op("send", signal_safe); #endif - if ( errno == EAGAIN ) - // Success: pipe is full and just need at least one byte in it. - break; + if ( errno == EAGAIN ) + // Success: pipe is full and just need at least one byte in it. + break; - if ( errno == EINTR ) - // Interrupted: try again. - continue; + if ( errno == EINTR ) + // Interrupted: try again. + continue; - bad_pipe_op("write", signal_safe); - } + bad_pipe_op("write", signal_safe); + } - // No error, but didn't write a byte: try again. - } - } + // No error, but didn't write a byte: try again. + } +} -int Flare::Extinguish(bool signal_safe) - { - int rval = 0; - char tmp[256]; +int Flare::Extinguish(bool signal_safe) { + int rval = 0; + char tmp[256]; - for ( ;; ) - { + for ( ;; ) { #ifndef _MSC_VER - int n = read(pipe.ReadFD(), &tmp, sizeof(tmp)); + int n = read(pipe.ReadFD(), &tmp, sizeof(tmp)); #else - int n = recv(recvfd, tmp, sizeof(tmp), 0); + int n = recv(recvfd, tmp, sizeof(tmp), 0); #endif - if ( n >= 0 ) - { - rval += n; - // Pipe may not be empty yet: try again. - continue; - } + if ( n >= 0 ) { + rval += n; + // Pipe may not be empty yet: try again. + continue; + } #ifdef _MSC_VER - if ( WSAGetLastError() == WSAEWOULDBLOCK ) - break; - errno = WSAGetLastError(); - bad_pipe_op("recv", signal_safe); + if ( WSAGetLastError() == WSAEWOULDBLOCK ) + break; + errno = WSAGetLastError(); + bad_pipe_op("recv", signal_safe); #endif - if ( errno == EAGAIN ) - // Success: pipe is now empty. - break; + if ( errno == EAGAIN ) + // Success: pipe is now empty. + break; - if ( errno == EINTR ) - // Interrupted: try again. - continue; + if ( errno == EINTR ) + // Interrupted: try again. + continue; - bad_pipe_op("read", signal_safe); - } + bad_pipe_op("read", signal_safe); + } - return rval; - } + return rval; +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Flare.h b/src/Flare.h index affae42dfd..1a1586c853 100644 --- a/src/Flare.h +++ b/src/Flare.h @@ -6,57 +6,55 @@ #include "Pipe.h" #endif -namespace zeek::detail - { +namespace zeek::detail { -class Flare - { +class Flare { public: - /** - * Create a flare object that can be used to signal a "ready" status via - * a file descriptor that may be integrated with select(), poll(), etc. - * Not thread-safe, but that should only require Fire()/Extinguish() calls - * to be made mutually exclusive (across all copies of a Flare). - */ - Flare(); + /** + * Create a flare object that can be used to signal a "ready" status via + * a file descriptor that may be integrated with select(), poll(), etc. + * Not thread-safe, but that should only require Fire()/Extinguish() calls + * to be made mutually exclusive (across all copies of a Flare). + */ + Flare(); - /** - * @return a file descriptor that will become ready if the flare has been - * Fire()'d and not yet Extinguished()'d. - */ - int FD() const + /** + * @return a file descriptor that will become ready if the flare has been + * Fire()'d and not yet Extinguished()'d. + */ + int FD() const #ifndef _MSC_VER - { - return pipe.ReadFD(); - } + { + return pipe.ReadFD(); + } #else - { - return recvfd; - } + { + return recvfd; + } #endif - /** - * Put the object in the "ready" state. - * @param signal_safe whether to skip error-reporting functionality that - * is not async-signal-safe (errors still abort the process regardless) - */ - void Fire(bool signal_safe = false); + /** + * Put the object in the "ready" state. + * @param signal_safe whether to skip error-reporting functionality that + * is not async-signal-safe (errors still abort the process regardless) + */ + void Fire(bool signal_safe = false); - /** - * Take the object out of the "ready" state. - * @param signal_safe whether to skip error-reporting functionality that - * is not async-signal-safe (errors still abort the process regardless) - * @return number of bytes read from the pipe, corresponds to the number - * of times Fire() was called. - */ - int Extinguish(bool signal_safe = false); + /** + * Take the object out of the "ready" state. + * @param signal_safe whether to skip error-reporting functionality that + * is not async-signal-safe (errors still abort the process regardless) + * @return number of bytes read from the pipe, corresponds to the number + * of times Fire() was called. + */ + int Extinguish(bool signal_safe = false); private: #ifndef _MSC_VER - Pipe pipe; + Pipe pipe; #else - int sendfd, recvfd; + int sendfd, recvfd; #endif - }; +}; - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Frag.cc b/src/Frag.cc index fa5d170761..5995080e18 100644 --- a/src/Frag.cc +++ b/src/Frag.cc @@ -14,371 +14,328 @@ constexpr uint32_t MIN_ACCEPTABLE_FRAG_SIZE = 64; constexpr uint32_t MAX_ACCEPTABLE_FRAG_SIZE = 64000; -namespace zeek::detail - { - -FragTimer::~FragTimer() - { - if ( f ) - f->ClearTimer(); - } - -void FragTimer::Dispatch(double t, bool /* is_expire */) - { - if ( f ) - f->Expire(t); - else - reporter->InternalWarning("fragment timer dispatched w/o reassembler"); - } - -FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr& ip, - const u_char* pkt, const FragReassemblerKey& k, double t) - : Reassembler(0, REASSEM_FRAG) - { - s = arg_s; - key = k; - - const struct ip* ip4 = ip->IP4_Hdr(); - if ( ip4 ) - { - proto_hdr_len = ip->HdrLen(); - proto_hdr = new u_char[64]; // max IP header + slop - // Don't do a structure copy - need to pick up options, too. - memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len); - } - else - { - proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header - proto_hdr = new u_char[proto_hdr_len]; - memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len); - } - - reassembled_pkt = nullptr; - frag_size = 0; // flag meaning "not known" - next_proto = ip->NextProto(); - - if ( frag_timeout != 0.0 ) - { - expire_timer = new FragTimer(this, t + frag_timeout); - timer_mgr->Add(expire_timer); - } - else - expire_timer = nullptr; - - AddFragment(t, ip, pkt); - } - -FragReassembler::~FragReassembler() - { - DeleteTimer(); - delete[] proto_hdr; - } - -void FragReassembler::AddFragment(double t, const std::shared_ptr& ip, const u_char* pkt) - { - const struct ip* ip4 = ip->IP4_Hdr(); - - if ( ip4 ) - { - if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p || - ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl ) - // || ip4->ip_tos != proto_hdr->ip_tos - // don't check TOS, there's at least one stack that actually - // uses different values, and it's hard to see an associated - // attack. - s->Weird("fragment_protocol_inconsistency", ip.get()); - } - else - { - if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len ) - s->Weird("fragment_protocol_inconsistency", ip.get()); - // TODO: more detailed unfrag header consistency checks? - } - - if ( ip->DF() ) - // Linux MTU discovery for UDP can do this, for example. - s->Weird("fragment_with_DF", ip.get()); - - uint16_t offset = ip->FragOffset(); - uint32_t len = ip->TotalLen(); - uint16_t hdr_len = ip->HdrLen(); - - if ( len < hdr_len ) - { - s->Weird("fragment_protocol_inconsistency", ip.get()); - return; - } - - uint64_t upper_seq = offset + len - hdr_len; - - if ( ! offset ) - // Make sure to use the first fragment header's next field. - next_proto = ip->NextProto(); - - if ( ! ip->MF() ) - { - // Last fragment. - if ( frag_size == 0 ) - frag_size = upper_seq; - - else if ( upper_seq != frag_size ) - { - s->Weird("fragment_size_inconsistency", ip.get()); - - if ( upper_seq > frag_size ) - frag_size = upper_seq; - } - } - - else if ( len < MIN_ACCEPTABLE_FRAG_SIZE ) - s->Weird("excessively_small_fragment", ip.get()); - - if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE ) - s->Weird("excessively_large_fragment", ip.get()); - - if ( frag_size && upper_seq > frag_size ) - { - // This can happen if we receive a fragment that's *not* - // the last fragment, but still imputes a size that's - // larger than the size we derived from a previously-seen - // "last fragment". - - s->Weird("fragment_size_inconsistency", ip.get()); - frag_size = upper_seq; - } - - // Do we need to check for consistent options? That's tricky - // for things like LSRR that get modified in route. - - // Remove header. - pkt += hdr_len; - len -= hdr_len; - - NewBlock(run_state::network_time, offset, len, pkt); - } - -void FragReassembler::Weird(const char* name) const - { - unsigned int version = ((const ip*)proto_hdr)->ip_v; - - if ( version == 4 ) - { - IP_Hdr hdr((const ip*)proto_hdr, false); - s->Weird(name, &hdr); - } - - else if ( version == 6 ) - { - IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len); - s->Weird(name, &hdr); - } - - else - { - reporter->InternalWarning("Unexpected IP version in FragReassembler"); - reporter->Weird(name); - } - } - -void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n) - { - if ( memcmp((const void*)b1, (const void*)b2, n) ) - Weird("fragment_inconsistency"); - else - Weird("fragment_overlap"); - } - -void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) - { - auto it = block_list.Begin(); - - if ( it->second.seq > 0 || ! frag_size ) - // For sure don't have it all yet. - return; - - auto next = std::next(it); - - // We might have it all - look for contiguous all the way. - while ( next != block_list.End() ) - { - if ( it->second.upper != next->second.seq ) - break; - - ++it; - ++next; - } - - const auto& last = block_list.LastBlock(); - - if ( next != block_list.End() ) - { - // We have a hole. - if ( it->second.upper >= frag_size ) - { - // We're stuck. The point where we stopped is - // contiguous up through the expected end of - // the fragment, but there's more stuff still - // beyond it, which is not contiguous. This - // can happen for benign reasons when we're - // intermingling parts of two fragmented packets. - Weird("fragment_size_inconsistency"); - - // We decide to analyze the contiguous portion now. - // Extend the fragment up through the end of what - // we have. - frag_size = it->second.upper; - } - else - return; - } - - else if ( last.upper > frag_size ) - { - Weird("fragment_size_inconsistency"); - frag_size = last.upper; - } - - else if ( last.upper < frag_size ) - // Missing the tail. - return; - - // We have it all. Compute the expected size of the fragment. - uint64_t n = proto_hdr_len + frag_size; - - // It's possible that we have blocks associated with this fragment - // that exceed this size, if we saw MF fragments (which don't lead - // to us setting frag_size) that went beyond the size indicated by - // the final, non-MF fragment. This can happen for benign reasons - // due to intermingling of fragments from an older datagram with those - // for a more recent one. - - u_char* pkt = new u_char[n]; - memcpy((void*)pkt, (const void*)proto_hdr, proto_hdr_len); - - u_char* pkt_start = pkt; - - pkt += proto_hdr_len; - - for ( it = block_list.Begin(); it != block_list.End(); ++it ) - { - const auto& b = it->second; - - if ( it != block_list.Begin() ) - { - const auto& prev = std::prev(it)->second; - - // If we're above a hole, stop. This can happen because - // the logic above regarding a hole that's above the - // expected fragment size. - if ( prev.upper < b.seq ) - break; - } - - if ( b.upper > n ) - { - reporter->InternalWarning("bad fragment reassembly"); - DeleteTimer(); - Expire(run_state::network_time); - delete[] pkt_start; - return; - } - - memcpy(&pkt[b.seq], b.block, b.upper - b.seq); - } - - reassembled_pkt.reset(); - - unsigned int version = ((const struct ip*)pkt_start)->ip_v; - - if ( version == 4 ) - { - struct ip* reassem4 = (struct ip*)pkt_start; - reassem4->ip_len = htons(frag_size + proto_hdr_len); - reassembled_pkt = std::make_shared(reassem4, true, true); - DeleteTimer(); - } - - else if ( version == 6 ) - { - struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start; - reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40); - const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n); - reassembled_pkt = std::make_shared(reassem6, true, n, chain, true); - DeleteTimer(); - } - - else - { - reporter->InternalWarning("bad IP version in fragment reassembly: %d", version); - delete[] pkt_start; - } - } - -void FragReassembler::Expire(double t) - { - block_list.Clear(); - expire_timer->ClearReassembler(); - expire_timer = nullptr; // timer manager will delete it - - fragment_mgr->Remove(this); - } - -void FragReassembler::DeleteTimer() - { - if ( expire_timer ) - { - expire_timer->ClearReassembler(); - timer_mgr->Cancel(expire_timer); - expire_timer = nullptr; // timer manager will delete it - } - } - -FragmentManager::~FragmentManager() - { - Clear(); - } - -FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr& ip, - const u_char* pkt) - { - uint32_t frag_id = ip->ID(); - FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id); - - FragReassembler* f = nullptr; - auto it = fragments.find(key); - if ( it != fragments.end() ) - f = it->second; - - if ( ! f ) - { - f = new FragReassembler(session_mgr, ip, pkt, key, t); - fragments[key] = f; - if ( fragments.size() > max_fragments ) - max_fragments = fragments.size(); - return f; - } - - f->AddFragment(t, ip, pkt); - return f; - } - -void FragmentManager::Clear() - { - for ( const auto& entry : fragments ) - Unref(entry.second); - - fragments.clear(); - } - -void FragmentManager::Remove(detail::FragReassembler* f) - { - if ( ! f ) - return; - - if ( fragments.erase(f->Key()) == 0 ) - reporter->InternalWarning("fragment reassembler not in dict"); - - Unref(f); - } - - } // namespace zeek::detail +namespace zeek::detail { + +FragTimer::~FragTimer() { + if ( f ) + f->ClearTimer(); +} + +void FragTimer::Dispatch(double t, bool /* is_expire */) { + if ( f ) + f->Expire(t); + else + reporter->InternalWarning("fragment timer dispatched w/o reassembler"); +} + +FragReassembler::FragReassembler(session::Manager* arg_s, const std::shared_ptr& ip, const u_char* pkt, + const FragReassemblerKey& k, double t) + : Reassembler(0, REASSEM_FRAG) { + s = arg_s; + key = k; + + const struct ip* ip4 = ip->IP4_Hdr(); + if ( ip4 ) { + proto_hdr_len = ip->HdrLen(); + proto_hdr = new u_char[64]; // max IP header + slop + // Don't do a structure copy - need to pick up options, too. + memcpy((void*)proto_hdr, (const void*)ip4, proto_hdr_len); + } + else { + proto_hdr_len = ip->HdrLen() - 8; // minus length of fragment header + proto_hdr = new u_char[proto_hdr_len]; + memcpy(proto_hdr, ip->IP6_Hdr(), proto_hdr_len); + } + + reassembled_pkt = nullptr; + frag_size = 0; // flag meaning "not known" + next_proto = ip->NextProto(); + + if ( frag_timeout != 0.0 ) { + expire_timer = new FragTimer(this, t + frag_timeout); + timer_mgr->Add(expire_timer); + } + else + expire_timer = nullptr; + + AddFragment(t, ip, pkt); +} + +FragReassembler::~FragReassembler() { + DeleteTimer(); + delete[] proto_hdr; +} + +void FragReassembler::AddFragment(double t, const std::shared_ptr& ip, const u_char* pkt) { + const struct ip* ip4 = ip->IP4_Hdr(); + + if ( ip4 ) { + if ( ip4->ip_p != ((const struct ip*)proto_hdr)->ip_p || ip4->ip_hl != ((const struct ip*)proto_hdr)->ip_hl ) + // || ip4->ip_tos != proto_hdr->ip_tos + // don't check TOS, there's at least one stack that actually + // uses different values, and it's hard to see an associated + // attack. + s->Weird("fragment_protocol_inconsistency", ip.get()); + } + else { + if ( ip->NextProto() != next_proto || ip->HdrLen() - 8 != proto_hdr_len ) + s->Weird("fragment_protocol_inconsistency", ip.get()); + // TODO: more detailed unfrag header consistency checks? + } + + if ( ip->DF() ) + // Linux MTU discovery for UDP can do this, for example. + s->Weird("fragment_with_DF", ip.get()); + + uint16_t offset = ip->FragOffset(); + uint32_t len = ip->TotalLen(); + uint16_t hdr_len = ip->HdrLen(); + + if ( len < hdr_len ) { + s->Weird("fragment_protocol_inconsistency", ip.get()); + return; + } + + uint64_t upper_seq = offset + len - hdr_len; + + if ( ! offset ) + // Make sure to use the first fragment header's next field. + next_proto = ip->NextProto(); + + if ( ! ip->MF() ) { + // Last fragment. + if ( frag_size == 0 ) + frag_size = upper_seq; + + else if ( upper_seq != frag_size ) { + s->Weird("fragment_size_inconsistency", ip.get()); + + if ( upper_seq > frag_size ) + frag_size = upper_seq; + } + } + + else if ( len < MIN_ACCEPTABLE_FRAG_SIZE ) + s->Weird("excessively_small_fragment", ip.get()); + + if ( upper_seq > MAX_ACCEPTABLE_FRAG_SIZE ) + s->Weird("excessively_large_fragment", ip.get()); + + if ( frag_size && upper_seq > frag_size ) { + // This can happen if we receive a fragment that's *not* + // the last fragment, but still imputes a size that's + // larger than the size we derived from a previously-seen + // "last fragment". + + s->Weird("fragment_size_inconsistency", ip.get()); + frag_size = upper_seq; + } + + // Do we need to check for consistent options? That's tricky + // for things like LSRR that get modified in route. + + // Remove header. + pkt += hdr_len; + len -= hdr_len; + + NewBlock(run_state::network_time, offset, len, pkt); +} + +void FragReassembler::Weird(const char* name) const { + unsigned int version = ((const ip*)proto_hdr)->ip_v; + + if ( version == 4 ) { + IP_Hdr hdr((const ip*)proto_hdr, false); + s->Weird(name, &hdr); + } + + else if ( version == 6 ) { + IP_Hdr hdr((const ip6_hdr*)proto_hdr, false, proto_hdr_len); + s->Weird(name, &hdr); + } + + else { + reporter->InternalWarning("Unexpected IP version in FragReassembler"); + reporter->Weird(name); + } +} + +void FragReassembler::Overlap(const u_char* b1, const u_char* b2, uint64_t n) { + if ( memcmp((const void*)b1, (const void*)b2, n) ) + Weird("fragment_inconsistency"); + else + Weird("fragment_overlap"); +} + +void FragReassembler::BlockInserted(DataBlockMap::const_iterator /* it */) { + auto it = block_list.Begin(); + + if ( it->second.seq > 0 || ! frag_size ) + // For sure don't have it all yet. + return; + + auto next = std::next(it); + + // We might have it all - look for contiguous all the way. + while ( next != block_list.End() ) { + if ( it->second.upper != next->second.seq ) + break; + + ++it; + ++next; + } + + const auto& last = block_list.LastBlock(); + + if ( next != block_list.End() ) { + // We have a hole. + if ( it->second.upper >= frag_size ) { + // We're stuck. The point where we stopped is + // contiguous up through the expected end of + // the fragment, but there's more stuff still + // beyond it, which is not contiguous. This + // can happen for benign reasons when we're + // intermingling parts of two fragmented packets. + Weird("fragment_size_inconsistency"); + + // We decide to analyze the contiguous portion now. + // Extend the fragment up through the end of what + // we have. + frag_size = it->second.upper; + } + else + return; + } + + else if ( last.upper > frag_size ) { + Weird("fragment_size_inconsistency"); + frag_size = last.upper; + } + + else if ( last.upper < frag_size ) + // Missing the tail. + return; + + // We have it all. Compute the expected size of the fragment. + uint64_t n = proto_hdr_len + frag_size; + + // It's possible that we have blocks associated with this fragment + // that exceed this size, if we saw MF fragments (which don't lead + // to us setting frag_size) that went beyond the size indicated by + // the final, non-MF fragment. This can happen for benign reasons + // due to intermingling of fragments from an older datagram with those + // for a more recent one. + + u_char* pkt = new u_char[n]; + memcpy((void*)pkt, (const void*)proto_hdr, proto_hdr_len); + + u_char* pkt_start = pkt; + + pkt += proto_hdr_len; + + for ( it = block_list.Begin(); it != block_list.End(); ++it ) { + const auto& b = it->second; + + if ( it != block_list.Begin() ) { + const auto& prev = std::prev(it)->second; + + // If we're above a hole, stop. This can happen because + // the logic above regarding a hole that's above the + // expected fragment size. + if ( prev.upper < b.seq ) + break; + } + + if ( b.upper > n ) { + reporter->InternalWarning("bad fragment reassembly"); + DeleteTimer(); + Expire(run_state::network_time); + delete[] pkt_start; + return; + } + + memcpy(&pkt[b.seq], b.block, b.upper - b.seq); + } + + reassembled_pkt.reset(); + + unsigned int version = ((const struct ip*)pkt_start)->ip_v; + + if ( version == 4 ) { + struct ip* reassem4 = (struct ip*)pkt_start; + reassem4->ip_len = htons(frag_size + proto_hdr_len); + reassembled_pkt = std::make_shared(reassem4, true, true); + DeleteTimer(); + } + + else if ( version == 6 ) { + struct ip6_hdr* reassem6 = (struct ip6_hdr*)pkt_start; + reassem6->ip6_plen = htons(frag_size + proto_hdr_len - 40); + const IPv6_Hdr_Chain* chain = new IPv6_Hdr_Chain(reassem6, next_proto, n); + reassembled_pkt = std::make_shared(reassem6, true, n, chain, true); + DeleteTimer(); + } + + else { + reporter->InternalWarning("bad IP version in fragment reassembly: %d", version); + delete[] pkt_start; + } +} + +void FragReassembler::Expire(double t) { + block_list.Clear(); + expire_timer->ClearReassembler(); + expire_timer = nullptr; // timer manager will delete it + + fragment_mgr->Remove(this); +} + +void FragReassembler::DeleteTimer() { + if ( expire_timer ) { + expire_timer->ClearReassembler(); + timer_mgr->Cancel(expire_timer); + expire_timer = nullptr; // timer manager will delete it + } +} + +FragmentManager::~FragmentManager() { Clear(); } + +FragReassembler* FragmentManager::NextFragment(double t, const std::shared_ptr& ip, const u_char* pkt) { + uint32_t frag_id = ip->ID(); + FragReassemblerKey key = std::make_tuple(ip->SrcAddr(), ip->DstAddr(), frag_id); + + FragReassembler* f = nullptr; + auto it = fragments.find(key); + if ( it != fragments.end() ) + f = it->second; + + if ( ! f ) { + f = new FragReassembler(session_mgr, ip, pkt, key, t); + fragments[key] = f; + if ( fragments.size() > max_fragments ) + max_fragments = fragments.size(); + return f; + } + + f->AddFragment(t, ip, pkt); + return f; +} + +void FragmentManager::Clear() { + for ( const auto& entry : fragments ) + Unref(entry.second); + + fragments.clear(); +} + +void FragmentManager::Remove(detail::FragReassembler* f) { + if ( ! f ) + return; + + if ( fragments.erase(f->Key()) == 0 ) + reporter->InternalWarning("fragment reassembler not in dict"); + + Unref(f); +} + +} // namespace zeek::detail diff --git a/src/Frag.h b/src/Frag.h index 7b8c2cac23..e705b4304c 100644 --- a/src/Frag.h +++ b/src/Frag.h @@ -10,102 +10,95 @@ #include "zeek/Timer.h" #include "zeek/util.h" // for zeek_uint_t -namespace zeek - { +namespace zeek { class IP_Hdr; -namespace session - { +namespace session { class Manager; - } +} -namespace detail - { +namespace detail { class FragReassembler; class FragTimer; using FragReassemblerKey = std::tuple; -class FragReassembler : public Reassembler - { +class FragReassembler : public Reassembler { public: - FragReassembler(session::Manager* s, const std::shared_ptr& ip, const u_char* pkt, - const FragReassemblerKey& k, double t); - ~FragReassembler() override; + FragReassembler(session::Manager* s, const std::shared_ptr& ip, const u_char* pkt, + const FragReassemblerKey& k, double t); + ~FragReassembler() override; - void AddFragment(double t, const std::shared_ptr& ip, const u_char* pkt); + void AddFragment(double t, const std::shared_ptr& ip, const u_char* pkt); - void Expire(double t); - void DeleteTimer(); - void ClearTimer() { expire_timer = nullptr; } + void Expire(double t); + void DeleteTimer(); + void ClearTimer() { expire_timer = nullptr; } - std::shared_ptr ReassembledPkt() { return std::move(reassembled_pkt); } - const FragReassemblerKey& Key() const { return key; } + std::shared_ptr ReassembledPkt() { return std::move(reassembled_pkt); } + const FragReassemblerKey& Key() const { return key; } protected: - void BlockInserted(DataBlockMap::const_iterator it) override; - void Overlap(const u_char* b1, const u_char* b2, uint64_t n) override; - void Weird(const char* name) const; + void BlockInserted(DataBlockMap::const_iterator it) override; + void Overlap(const u_char* b1, const u_char* b2, uint64_t n) override; + void Weird(const char* name) const; - u_char* proto_hdr; - std::shared_ptr reassembled_pkt; - session::Manager* s; - uint64_t frag_size; // size of fully reassembled fragment - FragReassemblerKey key; - uint16_t next_proto; // first IPv6 fragment header's next proto field - uint16_t proto_hdr_len; + u_char* proto_hdr; + std::shared_ptr reassembled_pkt; + session::Manager* s; + uint64_t frag_size; // size of fully reassembled fragment + FragReassemblerKey key; + uint16_t next_proto; // first IPv6 fragment header's next proto field + uint16_t proto_hdr_len; - FragTimer* expire_timer; - }; + FragTimer* expire_timer; +}; -class FragTimer final : public Timer - { +class FragTimer final : public Timer { public: - FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; } - ~FragTimer() override; + FragTimer(FragReassembler* arg_f, double arg_t) : Timer(arg_t, TIMER_FRAG) { f = arg_f; } + ~FragTimer() override; - void Dispatch(double t, bool is_expire) override; + void Dispatch(double t, bool is_expire) override; - // Break the association between this timer and its creator. - void ClearReassembler() { f = nullptr; } + // Break the association between this timer and its creator. + void ClearReassembler() { f = nullptr; } protected: - FragReassembler* f; - }; + FragReassembler* f; +}; -class FragmentManager - { +class FragmentManager { public: - FragmentManager() = default; - ~FragmentManager(); + FragmentManager() = default; + ~FragmentManager(); - FragReassembler* NextFragment(double t, const std::shared_ptr& ip, const u_char* pkt); - void Clear(); - void Remove(detail::FragReassembler* f); + FragReassembler* NextFragment(double t, const std::shared_ptr& ip, const u_char* pkt); + void Clear(); + void Remove(detail::FragReassembler* f); - size_t Size() const { return fragments.size(); } - size_t MaxFragments() const { return max_fragments; } + size_t Size() const { return fragments.size(); } + size_t MaxFragments() const { return max_fragments; } private: - using FragmentMap = std::map; - FragmentMap fragments; - size_t max_fragments = 0; - }; + using FragmentMap = std::map; + FragmentMap fragments; + size_t max_fragments = 0; +}; extern FragmentManager* fragment_mgr; -class FragReassemblerTracker - { +class FragReassemblerTracker { public: - FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) { } + FragReassemblerTracker(FragReassembler* f) : frag_reassembler(f) {} - ~FragReassemblerTracker() { fragment_mgr->Remove(frag_reassembler); } + ~FragReassemblerTracker() { fragment_mgr->Remove(frag_reassembler); } private: - FragReassembler* frag_reassembler; - }; + FragReassembler* frag_reassembler; +}; - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/Frame.cc b/src/Frame.cc index dd1a9dc598..b0d09c1ca3 100644 --- a/src/Frame.cc +++ b/src/Frame.cc @@ -13,211 +13,185 @@ std::vector g_frame_stack; -namespace zeek::detail - { +namespace zeek::detail { -Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) - { - size = arg_size; - frame = std::make_unique(size); - function = func; - func_args = fn_args; +Frame::Frame(int arg_size, const ScriptFunc* func, const zeek::Args* fn_args) { + size = arg_size; + frame = std::make_unique(size); + function = func; + func_args = fn_args; - next_stmt = nullptr; - break_before_next_stmt = false; - break_on_return = false; + next_stmt = nullptr; + break_before_next_stmt = false; + break_on_return = false; - call = nullptr; - delayed = false; + call = nullptr; + delayed = false; - // We could Ref()/Unref() the captures frame, but there's really - // no need because by definition this current frame exists to - // enable execution of the function, and its captures frame won't - // go away until the function itself goes away, which can only be - // after this frame does. - captures = function ? function->GetCapturesFrame() : nullptr; - captures_offset_map = function ? function->GetCapturesOffsetMap() : nullptr; - current_offset = 0; - } + // We could Ref()/Unref() the captures frame, but there's really + // no need because by definition this current frame exists to + // enable execution of the function, and its captures frame won't + // go away until the function itself goes away, which can only be + // after this frame does. + captures = function ? function->GetCapturesFrame() : nullptr; + captures_offset_map = function ? function->GetCapturesOffsetMap() : nullptr; + current_offset = 0; +} -void Frame::SetElement(int n, ValPtr v) - { - n += current_offset; - ASSERT(n >= 0 && n < size); - frame[n] = std::move(v); - } +void Frame::SetElement(int n, ValPtr v) { + n += current_offset; + ASSERT(n >= 0 && n < size); + frame[n] = std::move(v); +} -void Frame::SetElement(const ID* id, ValPtr v) - { - if ( captures ) - { - auto cap_off = captures_offset_map->find(id->Name()); - if ( cap_off != captures_offset_map->end() ) - { - captures->SetElement(cap_off->second, std::move(v)); - return; - } - } +void Frame::SetElement(const ID* id, ValPtr v) { + if ( captures ) { + auto cap_off = captures_offset_map->find(id->Name()); + if ( cap_off != captures_offset_map->end() ) { + captures->SetElement(cap_off->second, std::move(v)); + return; + } + } - SetElement(id->Offset(), std::move(v)); - } + SetElement(id->Offset(), std::move(v)); +} -const ValPtr& Frame::GetElementByID(const ID* id) const - { - if ( captures ) - { - auto cap_off = captures_offset_map->find(id->Name()); - if ( cap_off != captures_offset_map->end() ) - return captures->GetElement(cap_off->second); - } +const ValPtr& Frame::GetElementByID(const ID* id) const { + if ( captures ) { + auto cap_off = captures_offset_map->find(id->Name()); + if ( cap_off != captures_offset_map->end() ) + return captures->GetElement(cap_off->second); + } - return frame[id->Offset() + current_offset]; - } + return frame[id->Offset() + current_offset]; +} -void Frame::Reset(int startIdx) - { - for ( int i = startIdx + current_offset; i < size; ++i ) - frame[i] = nullptr; - } +void Frame::Reset(int startIdx) { + for ( int i = startIdx + current_offset; i < size; ++i ) + frame[i] = nullptr; +} -void Frame::Describe(ODesc* d) const - { - if ( ! d->IsBinary() ) - d->AddSP("frame"); +void Frame::Describe(ODesc* d) const { + if ( ! d->IsBinary() ) + d->AddSP("frame"); - if ( ! d->IsReadable() ) - { - d->Add(size); + if ( ! d->IsReadable() ) { + d->Add(size); - for ( int i = 0; i < size; ++i ) - { - d->Add(frame[i] != nullptr); - d->SP(); - } - } + for ( int i = 0; i < size; ++i ) { + d->Add(frame[i] != nullptr); + d->SP(); + } + } - for ( int i = 0; i < size; ++i ) - if ( frame[i] ) - frame[i]->Describe(d); - else if ( d->IsReadable() ) - d->Add(""); - } + for ( int i = 0; i < size; ++i ) + if ( frame[i] ) + frame[i]->Describe(d); + else if ( d->IsReadable() ) + d->Add(""); +} -Frame* Frame::Clone() const - { - Frame* other = new Frame(size, function, func_args); +Frame* Frame::Clone() const { + Frame* other = new Frame(size, function, func_args); - other->call = call; - other->assoc = assoc; - other->trigger = trigger; + other->call = call; + other->assoc = assoc; + other->trigger = trigger; - for ( int i = 0; i < size; i++ ) - if ( frame[i] ) - other->frame[i] = frame[i]->Clone(); + for ( int i = 0; i < size; i++ ) + if ( frame[i] ) + other->frame[i] = frame[i]->Clone(); - // Note, there's no need to clone "captures" or "captures_offset_map" - // since those get created fresh when constructing "other". + // Note, there's no need to clone "captures" or "captures_offset_map" + // since those get created fresh when constructing "other". - return other; - } + return other; +} -Frame* Frame::CloneForTrigger() const - { - Frame* other = new Frame(0, function, func_args); +Frame* Frame::CloneForTrigger() const { + Frame* other = new Frame(0, function, func_args); - other->call = call; - other->assoc = assoc; - other->trigger = trigger; + other->call = call; + other->assoc = assoc; + other->trigger = trigger; - return other; - } + return other; +} -static bool val_is_func(const ValPtr& v, ScriptFunc* func) - { - if ( v->GetType()->Tag() != TYPE_FUNC ) - return false; +static bool val_is_func(const ValPtr& v, ScriptFunc* func) { + if ( v->GetType()->Tag() != TYPE_FUNC ) + return false; - return v->AsFunc() == func; - } + return v->AsFunc() == func; +} -broker::expected Frame::Serialize() - { - broker::vector body; +broker::expected Frame::Serialize() { + broker::vector body; - for ( int i = 0; i < size; ++i ) - { - const auto& val = frame[i]; - auto expected = Broker::detail::val_to_data(val.get()); - if ( ! expected ) - return broker::ec::invalid_data; + for ( int i = 0; i < size; ++i ) { + const auto& val = frame[i]; + auto expected = Broker::detail::val_to_data(val.get()); + if ( ! expected ) + return broker::ec::invalid_data; - TypeTag tag = val->GetType()->Tag(); - broker::vector val_tuple{std::move(*expected), static_cast(tag)}; - body.emplace_back(std::move(val_tuple)); - } + TypeTag tag = val->GetType()->Tag(); + broker::vector val_tuple{std::move(*expected), static_cast(tag)}; + body.emplace_back(std::move(val_tuple)); + } - broker::vector rval; - rval.emplace_back(std::move(body)); + broker::vector rval; + rval.emplace_back(std::move(body)); - return {std::move(rval)}; - } + return {std::move(rval)}; +} -std::pair Frame::Unserialize(const broker::vector& data) - { - if ( data.size() == 0 ) - return std::make_pair(true, nullptr); +std::pair Frame::Unserialize(const broker::vector& data) { + if ( data.size() == 0 ) + return std::make_pair(true, nullptr); - auto where = data.begin(); - auto has_body = broker::get_if(*where); - if ( ! has_body ) - return std::make_pair(false, nullptr); + auto where = data.begin(); + auto has_body = broker::get_if(*where); + if ( ! has_body ) + return std::make_pair(false, nullptr); - broker::vector body = *has_body; - int frame_size = body.size(); - auto rf = make_intrusive(frame_size, nullptr, nullptr); + broker::vector body = *has_body; + int frame_size = body.size(); + auto rf = make_intrusive(frame_size, nullptr, nullptr); - for ( int i = 0; i < frame_size; ++i ) - { - auto has_vec = broker::get_if(body[i]); - if ( ! has_vec ) - continue; + for ( int i = 0; i < frame_size; ++i ) { + auto has_vec = broker::get_if(body[i]); + if ( ! has_vec ) + continue; - broker::vector val_tuple = *has_vec; - if ( val_tuple.size() != 2 ) - return std::make_pair(false, nullptr); + broker::vector val_tuple = *has_vec; + if ( val_tuple.size() != 2 ) + return std::make_pair(false, nullptr); - auto has_type = broker::get_if(val_tuple[1]); - if ( ! has_type ) - return std::make_pair(false, nullptr); + auto has_type = broker::get_if(val_tuple[1]); + if ( ! has_type ) + return std::make_pair(false, nullptr); - broker::integer g = *has_type; - Type t(static_cast(g)); + broker::integer g = *has_type; + Type t(static_cast(g)); - auto val = Broker::detail::data_to_val(std::move(val_tuple[0]), &t); - if ( ! val ) - return std::make_pair(false, nullptr); + auto val = Broker::detail::data_to_val(std::move(val_tuple[0]), &t); + if ( ! val ) + return std::make_pair(false, nullptr); - rf->frame[i] = std::move(val); - } + rf->frame[i] = std::move(val); + } - return std::make_pair(true, std::move(rf)); - } + return std::make_pair(true, std::move(rf)); +} -const detail::Location* Frame::GetCallLocation() const - { - // This is currently trivial, but we keep it as an explicit - // method because it can provide flexibility for compiled code. - return call->GetLocationInfo(); - } +const detail::Location* Frame::GetCallLocation() const { + // This is currently trivial, but we keep it as an explicit + // method because it can provide flexibility for compiled code. + return call->GetLocationInfo(); +} -void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) - { - trigger = std::move(arg_trigger); - } +void Frame::SetTrigger(trigger::TriggerPtr arg_trigger) { trigger = std::move(arg_trigger); } -void Frame::ClearTrigger() - { - trigger = nullptr; - } +void Frame::ClearTrigger() { trigger = nullptr; } - } +} // namespace zeek::detail diff --git a/src/Frame.h b/src/Frame.h index 8b169b42df..05407aaa97 100644 --- a/src/Frame.h +++ b/src/Frame.h @@ -17,250 +17,244 @@ #include "zeek/ZeekArgs.h" #include "zeek/ZeekList.h" // for typedef val_list -namespace zeek - { +namespace zeek { using ValPtr = IntrusivePtr; -namespace detail - { +namespace detail { class CallExpr; class ScriptFunc; using IDPtr = IntrusivePtr; -namespace trigger - { +namespace trigger { class Trigger; using TriggerPtr = IntrusivePtr; - } +} // namespace trigger class Frame; using FramePtr = IntrusivePtr; -class Frame : public Obj - { +class Frame : public Obj { public: - /** - * Constructs a new frame belonging to *func* with *fn_args* - * arguments. - * - * @param the size of the frame - * @param func the function that is creating this frame - * @param fn_args the arguments being passed to that function. - */ - Frame(int size, const ScriptFunc* func, const zeek::Args* fn_args); + /** + * Constructs a new frame belonging to *func* with *fn_args* + * arguments. + * + * @param the size of the frame + * @param func the function that is creating this frame + * @param fn_args the arguments being passed to that function. + */ + Frame(int size, const ScriptFunc* func, const zeek::Args* fn_args); - /** - * Returns the size of the frame. - * - * @return the number of elements in the frame. - */ - int FrameSize() const { return size; } + /** + * Returns the size of the frame. + * + * @return the number of elements in the frame. + */ + int FrameSize() const { return size; } - /** - * @param n the index to get. - * @return the value at index *n* of the underlying array. - */ - const ValPtr& GetElement(int n) const - { - // Note: technically this may want to adjust by current_offset, but - // in practice, this method is never called from anywhere other than - // function call invocation, where current_offset should be zero. - return frame[n]; - } + /** + * @param n the index to get. + * @return the value at index *n* of the underlying array. + */ + const ValPtr& GetElement(int n) const { + // Note: technically this may want to adjust by current_offset, but + // in practice, this method is never called from anywhere other than + // function call invocation, where current_offset should be zero. + return frame[n]; + } - /** - * Sets the element at index *n* of the underlying array to *v*. - * @param n the index to set - * @param v the value to set it to - */ - void SetElement(int n, ValPtr v); + /** + * Sets the element at index *n* of the underlying array to *v*. + * @param n the index to set + * @param v the value to set it to + */ + void SetElement(int n, ValPtr v); - /** - * Associates *id* and *v* in the frame. Future lookups of - * *id* will return *v*. - * - * @param id the ID to associate - * @param v the value to associate it with - */ - void SetElement(const ID* id, ValPtr v); - void SetElement(const IDPtr& id, ValPtr v) { SetElement(id.get(), std::move(v)); } + /** + * Associates *id* and *v* in the frame. Future lookups of + * *id* will return *v*. + * + * @param id the ID to associate + * @param v the value to associate it with + */ + void SetElement(const ID* id, ValPtr v); + void SetElement(const IDPtr& id, ValPtr v) { SetElement(id.get(), std::move(v)); } - /** - * Gets the value associated with *id* and returns it. Returns - * nullptr if no such element exists. - * - * @param id the id who's value to retrieve - * @return the value associated with *id* - */ - const ValPtr& GetElementByID(const IDPtr& id) const { return GetElementByID(id.get()); } + /** + * Gets the value associated with *id* and returns it. Returns + * nullptr if no such element exists. + * + * @param id the id who's value to retrieve + * @return the value associated with *id* + */ + const ValPtr& GetElementByID(const IDPtr& id) const { return GetElementByID(id.get()); } - /** - * Adjusts the current offset being used for frame accesses. - * This is in support of inlined functions. - * - * @param incr Amount by which to increase the frame offset. - * Use a negative value to shrink the offset. - */ - void AdjustOffset(int incr) { current_offset += incr; } + /** + * Adjusts the current offset being used for frame accesses. + * This is in support of inlined functions. + * + * @param incr Amount by which to increase the frame offset. + * Use a negative value to shrink the offset. + */ + void AdjustOffset(int incr) { current_offset += incr; } - /** - * Resets all of the indexes from [*startIdx, frame_size) in - * the Frame. - * @param the first index to unref. - */ - void Reset(int startIdx); + /** + * Resets all of the indexes from [*startIdx, frame_size) in + * the Frame. + * @param the first index to unref. + */ + void Reset(int startIdx); - /** - * Describes the frame and all of its values. - */ - void Describe(ODesc* d) const override; + /** + * Describes the frame and all of its values. + */ + void Describe(ODesc* d) const override; - /** - * @return the function that the frame is associated with. - */ - const ScriptFunc* GetFunction() const { return function; } + /** + * @return the function that the frame is associated with. + */ + const ScriptFunc* GetFunction() const { return function; } - /** - * @return the arguments passed to the function that this frame - * is associated with. - */ - const Args* GetFuncArgs() const { return func_args; } + /** + * @return the arguments passed to the function that this frame + * is associated with. + */ + const Args* GetFuncArgs() const { return func_args; } - /** - * Change the function that the frame is associated with. - * - * @param func the function for the frame to be associated with. - */ - void SetFunction(ScriptFunc* func) { function = func; } + /** + * Change the function that the frame is associated with. + * + * @param func the function for the frame to be associated with. + */ + void SetFunction(ScriptFunc* func) { function = func; } - /** - * Sets the next statement to be executed in the context of the frame. - * - * @param stmt the statement to set it to. - */ - void SetNextStmt(Stmt* stmt) { next_stmt = stmt; } + /** + * Sets the next statement to be executed in the context of the frame. + * + * @param stmt the statement to set it to. + */ + void SetNextStmt(Stmt* stmt) { next_stmt = stmt; } - /** - * @return the next statement to be executed in the context of the frame. - */ - Stmt* GetNextStmt() const { return next_stmt; } + /** + * @return the next statement to be executed in the context of the frame. + */ + Stmt* GetNextStmt() const { return next_stmt; } - /** Used to implement "next" command in debugger. */ - void BreakBeforeNextStmt(bool should_break) { break_before_next_stmt = should_break; } - bool BreakBeforeNextStmt() const { return break_before_next_stmt; } + /** Used to implement "next" command in debugger. */ + void BreakBeforeNextStmt(bool should_break) { break_before_next_stmt = should_break; } + bool BreakBeforeNextStmt() const { return break_before_next_stmt; } - /** Used to implement "finish" command in debugger. */ - void BreakOnReturn(bool should_break) { break_on_return = should_break; } - bool BreakOnReturn() const { return break_on_return; } + /** Used to implement "finish" command in debugger. */ + void BreakOnReturn(bool should_break) { break_on_return = should_break; } + bool BreakOnReturn() const { return break_on_return; } - /** - * Performs a deep copy of all the values in the current frame. - * - * @return a copy of this frame. - */ - Frame* Clone() const; + /** + * Performs a deep copy of all the values in the current frame. + * + * @return a copy of this frame. + */ + Frame* Clone() const; - /** - * Creates a copy of the frame that just includes its trigger context. - * - * @return a partial copy of this frame. - */ - Frame* CloneForTrigger() const; + /** + * Creates a copy of the frame that just includes its trigger context. + * + * @return a partial copy of this frame. + */ + Frame* CloneForTrigger() const; - /** - * Serializes the frame (only done for lambda/when captures) as a - * sequence of two-element vectors, the first element reflecting - * the frame value, the second its type. - */ - broker::expected Serialize(); + /** + * Serializes the frame (only done for lambda/when captures) as a + * sequence of two-element vectors, the first element reflecting + * the frame value, the second its type. + */ + broker::expected Serialize(); - /** - * Instantiates a Frame from a serialized one. - * - * @return a pair in which the first item is the status of the serialization; - * and the second is the unserialized frame with reference count +1, or - * null if the serialization wasn't successful. - */ - static std::pair Unserialize(const broker::vector& data); + /** + * Instantiates a Frame from a serialized one. + * + * @return a pair in which the first item is the status of the serialization; + * and the second is the unserialized frame with reference count +1, or + * null if the serialization wasn't successful. + */ + static std::pair Unserialize(const broker::vector& data); - // If the frame is run in the context of a trigger condition evaluation, - // the trigger needs to be registered. - void SetTrigger(trigger::TriggerPtr arg_trigger); - void ClearTrigger(); - trigger::Trigger* GetTrigger() const { return trigger.get(); } + // If the frame is run in the context of a trigger condition evaluation, + // the trigger needs to be registered. + void SetTrigger(trigger::TriggerPtr arg_trigger); + void ClearTrigger(); + trigger::Trigger* GetTrigger() const { return trigger.get(); } - void SetCall(const CallExpr* arg_call) - { - call = arg_call; - SetTriggerAssoc((void*)call); - } - void SetOnlyCall(const CallExpr* arg_call) { call = arg_call; } - const CallExpr* GetCall() const { return call; } + void SetCall(const CallExpr* arg_call) { + call = arg_call; + SetTriggerAssoc((void*)call); + } + void SetOnlyCall(const CallExpr* arg_call) { call = arg_call; } + const CallExpr* GetCall() const { return call; } - void SetTriggerAssoc(const void* arg_assoc) { assoc = arg_assoc; } - const void* GetTriggerAssoc() const { return assoc; } + void SetTriggerAssoc(const void* arg_assoc) { assoc = arg_assoc; } + const void* GetTriggerAssoc() const { return assoc; } - const detail::Location* GetCallLocation() const; + const detail::Location* GetCallLocation() const; - void SetDelayed() { delayed = true; } - bool HasDelayed() const { return delayed; } + void SetDelayed() { delayed = true; } + bool HasDelayed() const { return delayed; } private: - using OffsetMap = std::unordered_map; + using OffsetMap = std::unordered_map; - // This has a trivial form now, but used to hold additional - // information, which is why we abstract it away from just being - // a ValPtr. - using Element = ValPtr; + // This has a trivial form now, but used to hold additional + // information, which is why we abstract it away from just being + // a ValPtr. + using Element = ValPtr; - const ValPtr& GetElementByID(const ID* id) const; + const ValPtr& GetElementByID(const ID* id) const; - /** The number of vals that can be stored in this frame. */ - int size; + /** The number of vals that can be stored in this frame. */ + int size; - bool break_before_next_stmt; - bool break_on_return; - bool delayed; + bool break_before_next_stmt; + bool break_on_return; + bool delayed; - /** Associates ID's offsets with values. */ - std::unique_ptr frame; + /** Associates ID's offsets with values. */ + std::unique_ptr frame; - /** - * The offset we're currently using for references into the frame. - * This is how we support inlined functions without having to - * alter the offsets associated with their local variables. - */ - int current_offset; + /** + * The offset we're currently using for references into the frame. + * This is how we support inlined functions without having to + * alter the offsets associated with their local variables. + */ + int current_offset; - /** Frame used for lambda/when captures. */ - Frame* captures; + /** Frame used for lambda/when captures. */ + Frame* captures; - /** Maps IDs to offsets into the "captures" frame. If the ID - * isn't present, then it's not a capture. - */ - const OffsetMap* captures_offset_map; + /** Maps IDs to offsets into the "captures" frame. If the ID + * isn't present, then it's not a capture. + */ + const OffsetMap* captures_offset_map; - /** The function this frame is associated with. */ - const ScriptFunc* function; + /** The function this frame is associated with. */ + const ScriptFunc* function; - // The following is only needed for the debugger. - /** The arguments to the function that this Frame is associated with. */ - const zeek::Args* func_args; + // The following is only needed for the debugger. + /** The arguments to the function that this Frame is associated with. */ + const zeek::Args* func_args; - /** The next statement to be evaluated in the context of this frame. */ - Stmt* next_stmt; + /** The next statement to be evaluated in the context of this frame. */ + Stmt* next_stmt; - trigger::TriggerPtr trigger; - const CallExpr* call = nullptr; - const void* assoc = nullptr; - }; + trigger::TriggerPtr trigger; + const CallExpr* call = nullptr; + const void* assoc = nullptr; +}; - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek /** * If we stopped using this and instead just made a struct of the information diff --git a/src/Func.cc b/src/Func.cc index 885a4333cd..efd9b5a992 100644 --- a/src/Func.cc +++ b/src/Func.cc @@ -66,1151 +66,999 @@ extern RETSIGTYPE sig_handler(int signo); -namespace zeek::detail - { +namespace zeek::detail { std::vector call_stack; bool did_builtin_init = false; std::vector bif_initializers; static const std::pair empty_hook_result(false, nullptr); - } // namespace zeek::detail - -namespace zeek - { - -std::string render_call_stack() - { - std::string rval; - int lvl = 0; - - if ( ! detail::call_stack.empty() ) - rval += "| "; - - for ( auto it = detail::call_stack.rbegin(); it != detail::call_stack.rend(); ++it ) - { - if ( lvl > 0 ) - rval += " | "; - - auto& ci = *it; - auto name = ci.func->Name(); - std::string arg_desc; - - for ( const auto& arg : ci.args ) - { - ODesc d; - d.SetShort(); - arg->Describe(&d); - - if ( ! arg_desc.empty() ) - arg_desc += ", "; - - arg_desc += d.Description(); - } - - rval += util::fmt("#%d %s(%s)", lvl, name, arg_desc.data()); - - if ( ci.call ) - { - auto loc = ci.call->GetLocationInfo(); - rval += util::fmt(" at %s:%d", loc->filename, loc->first_line); - } - - ++lvl; - } - - if ( ! detail::call_stack.empty() ) - rval += " |"; - - return rval; - } - -void Func::AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body) - { - if ( ! new_body ) - new_body = ingr.Body(); - - AddBody(new_body, ingr.Inits(), ingr.FrameSize(), ingr.Priority(), ingr.Groups()); - } - -void Func::AddBody(detail::StmtPtr new_body, const std::vector& new_inits, - size_t new_frame_size, int priority) - { - std::set groups; - AddBody(new_body, new_inits, new_frame_size, priority, groups); - } - -void Func::AddBody(detail::StmtPtr new_body, size_t new_frame_size) - { - std::vector no_inits; - std::set no_groups; - AddBody(std::move(new_body), no_inits, new_frame_size, 0, no_groups); - } - -void Func::AddBody(detail::StmtPtr /* new_body */, - const std::vector& /* new_inits */, size_t /* new_frame_size */, - int /* priority */, const std::set& /* groups */) - { - Internal("Func::AddBody called"); - } - -void Func::SetScope(detail::ScopePtr newscope) - { - scope = std::move(newscope); - } - -FuncPtr Func::DoClone() - { - // By default, ok just to return a reference. Func does not have any state - // that is different across instances. - return {NewRef{}, this}; - } - -void Func::DescribeDebug(ODesc* d, const Args* args) const - { - d->Add(Name()); - - if ( args ) - { - d->Add("("); - const auto& func_args = GetType()->Params(); - auto num_fields = static_cast(func_args->NumFields()); - - for ( auto i = 0u; i < args->size(); ++i ) - { - // Handle varargs case (more args than formals). - if ( i >= num_fields ) - { - d->Add("vararg"); - int va_num = i - num_fields; - d->Add(va_num); - } - else - d->Add(func_args->FieldName(i)); - - d->Add(" = '"); - (*args)[i]->Describe(d); - - if ( i < args->size() - 1 ) - d->Add("', "); - else - d->Add("'"); - } - - d->Add(")"); - } - } - -detail::TraversalCode Func::Traverse(detail::TraversalCallback* cb) const - { - // FIXME: Make a fake scope for builtins? - auto old_scope = cb->current_scope; - cb->current_scope = scope; - - detail::TraversalCode tc = cb->PreFunction(this); - HANDLE_TC_STMT_PRE(tc); - - // FIXME: Traverse arguments to builtin functions, too. - if ( kind == SCRIPT_FUNC && scope ) - { - tc = scope->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - - for ( const auto& body : bodies ) - { - tc = body.stmts->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - } - } - - tc = cb->PostFunction(this); - - cb->current_scope = old_scope; - HANDLE_TC_STMT_POST(tc); - } - -void Func::CopyStateInto(Func* other) const - { - other->bodies = bodies; - other->scope = scope; - other->kind = kind; - - other->type = type; - - other->name = name; - } - -void Func::CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFlavor flavor) const - { - // Helper function factoring out this code from ScriptFunc:Call() for - // better readability. - - if ( ! handled ) - { - if ( hook_result ) - reporter->InternalError( - "plugin set processed flag to false but actually returned a value"); - - // The plugin result hasn't been processed yet (read: fall - // into ::Call method). - return; - } - - switch ( flavor ) - { - case FUNC_FLAVOR_EVENT: - if ( hook_result ) - reporter->InternalError("plugin returned non-void result for event %s", - this->Name()); - - break; - - case FUNC_FLAVOR_HOOK: - if ( hook_result->GetType()->Tag() != TYPE_BOOL ) - reporter->InternalError("plugin returned non-bool for hook %s", this->Name()); - - break; - - case FUNC_FLAVOR_FUNCTION: - { - const auto& yt = GetType()->Yield(); - - if ( (! yt) || yt->Tag() == TYPE_VOID ) - { - if ( hook_result ) - reporter->InternalError("plugin returned non-void result for void method %s", - this->Name()); - } - - else if ( hook_result && hook_result->GetType()->Tag() != yt->Tag() && - yt->Tag() != TYPE_ANY ) - { - reporter->InternalError("plugin returned wrong type (got %d, expecting %d) for %s", - hook_result->GetType()->Tag(), yt->Tag(), this->Name()); - } - - break; - } - } - } - -namespace detail - { - -ScriptFunc::ScriptFunc(const IDPtr& arg_id) : Func(SCRIPT_FUNC) - { - name = arg_id->Name(); - type = arg_id->GetType(); - frame_size = 0; - } - -ScriptFunc::ScriptFunc(std::string _name, FuncTypePtr ft, std::vector bs, - std::vector priorities) - { - name = std::move(_name); - frame_size = ft->ParamList()->GetTypes().size(); - type = std::move(ft); - - auto n = bs.size(); - ASSERT(n == priorities.size()); - - for ( auto i = 0u; i < n; ++i ) - { - Body b; - b.stmts = std::move(bs[i]); - b.priority = priorities[i]; - bodies.push_back(b); - } - - std::stable_sort(bodies.begin(), bodies.end()); - - if ( ! bodies.empty() ) - { - current_body = bodies[0].stmts; - current_priority = bodies[0].priority; - } - } - -ScriptFunc::~ScriptFunc() - { - if ( captures_vec ) - { - auto& cvec = *captures_vec; - auto& captures = *type->GetCaptures(); - for ( auto i = 0u; i < captures.size(); ++i ) - if ( captures[i].IsManaged() ) - ZVal::DeleteManagedType(cvec[i]); - } - - delete captures_frame; - delete captures_offset_mapping; - } - -bool ScriptFunc::IsPure() const - { - return std::all_of(bodies.begin(), bodies.end(), - [](const Body& b) - { - return b.stmts->IsPure(); - }); - } - -ValPtr ScriptFunc::Invoke(zeek::Args* args, Frame* parent) const - { - SegmentProfiler prof(segment_logger, location); - - if ( sample_logger ) - sample_logger->FunctionSeen(this); - - auto [handled, hook_result] = PLUGIN_HOOK_WITH_RESULT( - HOOK_CALL_FUNCTION, HookCallFunction(this, parent, args), empty_hook_result); - - CheckPluginResult(handled, hook_result, Flavor()); - - if ( handled ) - return hook_result; - - if ( bodies.empty() ) - { - // Can only happen for events and hooks. - assert(Flavor() == FUNC_FLAVOR_EVENT || Flavor() == FUNC_FLAVOR_HOOK); - return Flavor() == FUNC_FLAVOR_HOOK ? val_mgr->True() : nullptr; - } - - auto f = make_intrusive(frame_size, this, args); - - // Hand down any trigger. - if ( parent ) - { - f->SetTrigger({NewRef{}, parent->GetTrigger()}); - f->SetTriggerAssoc(parent->GetTriggerAssoc()); - } - - g_frame_stack.push_back(f.get()); // used for backtracing - const CallExpr* call_expr = parent ? parent->GetCall() : nullptr; - call_stack.emplace_back(CallInfo{call_expr, this, *args}); - - // If a script function is ever invoked with more arguments than it has - // parameters log an error and return. Most likely a "variadic function" - // that only has a single any parameter and is excluded from static type - // checking is involved. This should otherwise not be possible to hit. - auto num_params = static_cast(GetType()->Params()->NumFields()); - if ( args->size() > num_params ) - { - emit_builtin_exception("too many arguments for function call"); - return nullptr; - } - - if ( etm && Flavor() == FUNC_FLAVOR_EVENT ) - etm->StartEvent(this, args); - - if ( g_trace_state.DoTrace() ) - { - ODesc d; - DescribeDebug(&d, args); - - g_trace_state.LogTrace("%s called: %s\n", GetType()->FlavorString().c_str(), - d.Description()); - } - - StmtFlowType flow = FLOW_NEXT; - ValPtr result; - - for ( const auto& body : bodies ) - { - if ( body.disabled ) - continue; - - if ( sample_logger ) - sample_logger->LocationSeen(body.stmts->GetLocationInfo()); - - // Fill in the rest of the frame with the function's arguments. - for ( auto j = 0u; j < args->size(); ++j ) - { - const auto& arg = (*args)[j]; - - if ( f->GetElement(j) != arg ) - // Either not yet set, or somebody reassigned the frame slot. - f->SetElement(j, arg); - } - - if ( spm ) - spm->StartInvocation(this, body.stmts); - - f->Reset(args->size()); - - try - { - result = body.stmts->Exec(f.get(), flow); - } - - catch ( InterpreterException& e ) - { - // Already reported, but now determine whether to unwind further. - if ( Flavor() == FUNC_FLAVOR_FUNCTION ) - { - g_frame_stack.pop_back(); - call_stack.pop_back(); - // Result not set b/c exception was thrown - throw; - } - - // Continue exec'ing remaining bodies of hooks/events. - continue; - } - - if ( spm ) - spm->EndInvocation(); - - if ( f->HasDelayed() ) - { - assert(! result); - assert(parent); - parent->SetDelayed(); - break; - } - - if ( Flavor() == FUNC_FLAVOR_HOOK ) - { - // Ignore any return values of hook bodies, final return value - // depends on whether a body returns as a result of break statement. - result = nullptr; - - if ( flow == FLOW_BREAK ) - { - // Short-circuit execution of remaining hook handler bodies. - result = val_mgr->False(); - break; - } - } - } - - call_stack.pop_back(); - - if ( Flavor() == FUNC_FLAVOR_HOOK ) - { - if ( ! result ) - result = val_mgr->True(); - } - - else if ( etm && Flavor() == FUNC_FLAVOR_EVENT ) - etm->EndEvent(this, args); - - // Warn if the function returns something, but we returned from - // the function without an explicit return, or without a value. - else if ( GetType()->Yield() && GetType()->Yield()->Tag() != TYPE_VOID && - (flow != FLOW_RETURN /* we fell off the end */ || - ! result /* explicit return with no result */) && - ! f->HasDelayed() ) - reporter->Warning("non-void function returning without a value: %s", Name()); - - if ( result && g_trace_state.DoTrace() ) - { - ODesc d; - result->Describe(&d); - - g_trace_state.LogTrace("Function return: %s\n", d.Description()); - } - - g_frame_stack.pop_back(); - - return result; - } - -void ScriptFunc::CreateCaptures(Frame* f) - { - const auto& captures = type->GetCaptures(); - - if ( ! captures ) - return; - - // Create *either* a private Frame to hold the values of captured - // variables, and a mapping from those variables to their offsets - // in the Frame; *or* a ZVal frame if this script has a ZAM-compiled - // body. - ASSERT(bodies.size() == 1); - - if ( bodies[0].stmts->Tag() == STMT_ZAM ) - captures_vec = std::make_unique>(); - else - { - delete captures_frame; - delete captures_offset_mapping; - captures_frame = new Frame(captures->size(), this, nullptr); - captures_offset_mapping = new OffsetMap; - } - - int offset = 0; - for ( const auto& c : *captures ) - { - auto v = f->GetElementByID(c.Id()); - - if ( v ) - { - if ( c.IsDeepCopy() || ! v->Modifiable() ) - v = v->Clone(); - - if ( captures_vec ) - // Don't use v->GetType() here, as that might - // be "any" and we need to convert. - captures_vec->push_back(ZVal(v, c.Id()->GetType())); - else - captures_frame->SetElement(offset, std::move(v)); - } - - else if ( captures_vec ) - captures_vec->push_back(ZVal()); - - if ( ! captures_vec ) - captures_offset_mapping->insert_or_assign(c.Id()->Name(), offset); - - ++offset; - } - } - -void ScriptFunc::CreateCaptures(std::unique_ptr> cvec) - { - const auto& captures = *type->GetCaptures(); - - ASSERT(cvec->size() == captures.size()); - ASSERT(bodies.size() == 1 && bodies[0].stmts->Tag() == STMT_ZAM); - - captures_vec = std::move(cvec); - - auto n = captures.size(); - for ( auto i = 0U; i < n; ++i ) - { - auto& c_i = captures[i]; - auto& cv_i = (*captures_vec)[i]; - - if ( c_i.IsDeepCopy() ) - { - auto& t = c_i.Id()->GetType(); - auto new_cv_i = cv_i.ToVal(t)->Clone(); - if ( c_i.IsManaged() ) - ZVal::DeleteManagedType(cv_i); - - cv_i = ZVal(std::move(new_cv_i), t); - } - } - } - -void ScriptFunc::SetCaptures(Frame* f) - { - const auto& captures = type->GetCaptures(); - ASSERT(captures); - - delete captures_frame; - delete captures_offset_mapping; - captures_frame = f; - captures_offset_mapping = new OffsetMap; - - int offset = 0; - for ( const auto& c : *captures ) - { - captures_offset_mapping->insert_or_assign(c.Id()->Name(), offset); - ++offset; - } - } - -void ScriptFunc::AddBody(StmtPtr new_body, const std::vector& new_inits, - size_t new_frame_size, int priority, const std::set& groups) - { - if ( new_frame_size > frame_size ) - frame_size = new_frame_size; - - auto num_args = static_cast(GetType()->Params()->NumFields()); - - if ( num_args > frame_size ) - frame_size = num_args; - - new_body = AddInits(std::move(new_body), new_inits); - - if ( Flavor() == FUNC_FLAVOR_FUNCTION ) - { - // For functions, we replace the old body with the new one. - assert(bodies.size() <= 1); - bodies.clear(); - } - - Body b; - b.stmts = new_body; - b.groups = groups; - current_body = new_body; - current_priority = b.priority = priority; - - bodies.push_back(b); - std::stable_sort(bodies.begin(), bodies.end()); - } - -void ScriptFunc::ReplaceBody(const StmtPtr& old_body, StmtPtr new_body) - { - bool found_it = false; - - for ( auto body = bodies.begin(); body != bodies.end(); ++body ) - if ( body->stmts.get() == old_body.get() ) - { - if ( new_body ) - { - body->stmts = new_body; - current_priority = body->priority; - } - else - bodies.erase(body); - - found_it = true; - break; - } - - current_body = new_body; - } - -bool ScriptFunc::DeserializeCaptures(const broker::vector& data) - { - auto result = Frame::Unserialize(data); - - ASSERT(result.first); - - auto& f = result.second; - - if ( bodies[0].stmts->Tag() == STMT_ZAM ) - { - auto& captures = *type->GetCaptures(); - int n = f->FrameSize(); - - ASSERT(captures.size() == static_cast(n)); - - auto cvec = std::make_unique>(); - - for ( int i = 0; i < n; ++i ) - { - auto& f_i = f->GetElement(i); - cvec->push_back(ZVal(f_i, captures[i].Id()->GetType())); - } - - CreateCaptures(std::move(cvec)); - } - - else - SetCaptures(f.release()); - - return true; - } - -FuncPtr ScriptFunc::DoClone() - { - // ScriptFunc could hold a closure. In this case a clone of it must - // store a copy of this closure. - // - // We don't use make_intrusive<> directly because we're accessing - // a protected constructor. - auto other = IntrusivePtr{AdoptRef{}, new ScriptFunc()}; - - CopyStateInto(other.get()); - - other->frame_size = frame_size; - other->outer_ids = outer_ids; - - if ( captures_frame ) - { - other->captures_frame = captures_frame->Clone(); - other->captures_offset_mapping = new OffsetMap; - *other->captures_offset_mapping = *captures_offset_mapping; - } - - if ( captures_vec ) - { - auto cv_i = captures_vec->begin(); - other->captures_vec = std::make_unique>(); - for ( auto& c : *type->GetCaptures() ) - { - // Need to clone cv_i. - auto& t_i = c.Id()->GetType(); - auto cv_i_val = cv_i->ToVal(t_i)->Clone(); - other->captures_vec->push_back(ZVal(std::move(cv_i_val), t_i)); - ++cv_i; - } - } - - return other; - } - -broker::expected ScriptFunc::SerializeCaptures() const - { - if ( captures_vec ) - { - auto& cv = *captures_vec; - auto& captures = *type->GetCaptures(); - int n = captures_vec->size(); - auto temp_frame = make_intrusive(n, this, nullptr); - - for ( int i = 0; i < n; ++i ) - { - auto c_i = cv[i].ToVal(captures[i].Id()->GetType()); - temp_frame->SetElement(i, c_i); - } - - return temp_frame->Serialize(); - } - - if ( captures_frame ) - return captures_frame->Serialize(); - - // No captures, return an empty vector. - return broker::vector{}; - } - -void ScriptFunc::Describe(ODesc* d) const - { - d->Add(Name()); - - d->NL(); - d->AddCount(frame_size); - for ( const auto& body : bodies ) - { - body.stmts->AccessStats(d); - body.stmts->Describe(d); - } - } - -StmtPtr ScriptFunc::AddInits(StmtPtr body, const std::vector& inits) - { - if ( inits.empty() ) - return body; - - auto stmt_series = make_intrusive(); - stmt_series->Stmts().push_back(make_intrusive(inits)); - stmt_series->Stmts().push_back(std::move(body)); - - return stmt_series; - } - -BuiltinFunc::BuiltinFunc(built_in_func arg_func, const char* arg_name, bool arg_is_pure) - : Func(BUILTIN_FUNC) - { - func = arg_func; - name = make_full_var_name(GLOBAL_MODULE_NAME, arg_name); - is_pure = arg_is_pure; - - const auto& id = lookup_ID(Name(), GLOBAL_MODULE_NAME, false); - if ( ! id ) - reporter->InternalError("built-in function %s missing", Name()); - if ( id->HasVal() ) - reporter->InternalError("built-in function %s multiply defined", Name()); - - type = id->GetType(); - id->SetVal(make_intrusive(IntrusivePtr{NewRef{}, this})); - id->SetConst(); - } - -bool BuiltinFunc::IsPure() const - { - return is_pure; - } - -ValPtr BuiltinFunc::Invoke(Args* args, Frame* parent) const - { - if ( spm ) - spm->StartInvocation(this); - - SegmentProfiler prof(segment_logger, Name()); - - if ( sample_logger ) - sample_logger->FunctionSeen(this); - - auto [handled, hook_result] = PLUGIN_HOOK_WITH_RESULT( - HOOK_CALL_FUNCTION, HookCallFunction(this, parent, args), empty_hook_result); - - CheckPluginResult(handled, hook_result, FUNC_FLAVOR_FUNCTION); - - if ( handled ) - { - if ( spm ) - spm->EndInvocation(); - return hook_result; - } - - if ( g_trace_state.DoTrace() ) - { - ODesc d; - DescribeDebug(&d, args); - - g_trace_state.LogTrace("\tBuiltin Function called: %s\n", d.Description()); - } - - const CallExpr* call_expr = parent ? parent->GetCall() : nullptr; - call_stack.emplace_back(CallInfo{call_expr, this, *args}); - auto result = std::move(func(parent, args).rval); - call_stack.pop_back(); - - if ( result && g_trace_state.DoTrace() ) - { - ODesc d; - result->Describe(&d); - - g_trace_state.LogTrace("\tFunction return: %s\n", d.Description()); - } - - if ( spm ) - spm->EndInvocation(); - - return result; - } - -void BuiltinFunc::Describe(ODesc* d) const - { - d->Add(Name()); - d->AddCount(is_pure); - } - -bool check_built_in_call(BuiltinFunc* f, CallExpr* call) - { - if ( f->TheFunc() != BifFunc::fmt_bif ) - return true; - - const ExprPList& args = call->Args()->Exprs(); - if ( args.length() == 0 ) - { - // Empty calls are allowed, since you can't just - // use "print;" to get a blank line. - return true; - } - - const Expr* fmt_str_arg = args[0]; - if ( fmt_str_arg->GetType()->Tag() != TYPE_STRING ) - { - call->Error("first argument to util::fmt() needs to be a format string"); - return false; - } - - auto fmt_str_val = fmt_str_arg->Eval(nullptr); - - if ( fmt_str_val ) - { - const char* fmt_str = fmt_str_val->AsStringVal()->CheckString(); - - int num_fmt = 0; - while ( *fmt_str ) - { - if ( *(fmt_str++) != '%' ) - continue; - - if ( ! *fmt_str ) - { - call->Error("format string ends with bare '%'"); - return false; - } - - if ( *(fmt_str++) != '%' ) - // Not a "%%" escape. - ++num_fmt; - } - - if ( args.length() != num_fmt + 1 ) - { - call->Error( - "mismatch between format string to util::fmt() and number of arguments passed"); - return false; - } - } - - return true; - } +} // namespace zeek::detail + +namespace zeek { + +std::string render_call_stack() { + std::string rval; + int lvl = 0; + + if ( ! detail::call_stack.empty() ) + rval += "| "; + + for ( auto it = detail::call_stack.rbegin(); it != detail::call_stack.rend(); ++it ) { + if ( lvl > 0 ) + rval += " | "; + + auto& ci = *it; + auto name = ci.func->Name(); + std::string arg_desc; + + for ( const auto& arg : ci.args ) { + ODesc d; + d.SetShort(); + arg->Describe(&d); + + if ( ! arg_desc.empty() ) + arg_desc += ", "; + + arg_desc += d.Description(); + } + + rval += util::fmt("#%d %s(%s)", lvl, name, arg_desc.data()); + + if ( ci.call ) { + auto loc = ci.call->GetLocationInfo(); + rval += util::fmt(" at %s:%d", loc->filename, loc->first_line); + } + + ++lvl; + } + + if ( ! detail::call_stack.empty() ) + rval += " |"; + + return rval; +} + +void Func::AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body) { + if ( ! new_body ) + new_body = ingr.Body(); + + AddBody(new_body, ingr.Inits(), ingr.FrameSize(), ingr.Priority(), ingr.Groups()); +} + +void Func::AddBody(detail::StmtPtr new_body, const std::vector& new_inits, size_t new_frame_size, + int priority) { + std::set groups; + AddBody(new_body, new_inits, new_frame_size, priority, groups); +} + +void Func::AddBody(detail::StmtPtr new_body, size_t new_frame_size) { + std::vector no_inits; + std::set no_groups; + AddBody(std::move(new_body), no_inits, new_frame_size, 0, no_groups); +} + +void Func::AddBody(detail::StmtPtr /* new_body */, const std::vector& /* new_inits */, + size_t /* new_frame_size */, int /* priority */, const std::set& /* groups */) { + Internal("Func::AddBody called"); +} + +void Func::SetScope(detail::ScopePtr newscope) { scope = std::move(newscope); } + +FuncPtr Func::DoClone() { + // By default, ok just to return a reference. Func does not have any state + // that is different across instances. + return {NewRef{}, this}; +} + +void Func::DescribeDebug(ODesc* d, const Args* args) const { + d->Add(Name()); + + if ( args ) { + d->Add("("); + const auto& func_args = GetType()->Params(); + auto num_fields = static_cast(func_args->NumFields()); + + for ( auto i = 0u; i < args->size(); ++i ) { + // Handle varargs case (more args than formals). + if ( i >= num_fields ) { + d->Add("vararg"); + int va_num = i - num_fields; + d->Add(va_num); + } + else + d->Add(func_args->FieldName(i)); + + d->Add(" = '"); + (*args)[i]->Describe(d); + + if ( i < args->size() - 1 ) + d->Add("', "); + else + d->Add("'"); + } + + d->Add(")"); + } +} + +detail::TraversalCode Func::Traverse(detail::TraversalCallback* cb) const { + // FIXME: Make a fake scope for builtins? + auto old_scope = cb->current_scope; + cb->current_scope = scope; + + detail::TraversalCode tc = cb->PreFunction(this); + HANDLE_TC_STMT_PRE(tc); + + // FIXME: Traverse arguments to builtin functions, too. + if ( kind == SCRIPT_FUNC && scope ) { + tc = scope->Traverse(cb); + HANDLE_TC_STMT_PRE(tc); + + for ( const auto& body : bodies ) { + tc = body.stmts->Traverse(cb); + HANDLE_TC_STMT_PRE(tc); + } + } + + tc = cb->PostFunction(this); + + cb->current_scope = old_scope; + HANDLE_TC_STMT_POST(tc); +} + +void Func::CopyStateInto(Func* other) const { + other->bodies = bodies; + other->scope = scope; + other->kind = kind; + + other->type = type; + + other->name = name; +} + +void Func::CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFlavor flavor) const { + // Helper function factoring out this code from ScriptFunc:Call() for + // better readability. + + if ( ! handled ) { + if ( hook_result ) + reporter->InternalError("plugin set processed flag to false but actually returned a value"); + + // The plugin result hasn't been processed yet (read: fall + // into ::Call method). + return; + } + + switch ( flavor ) { + case FUNC_FLAVOR_EVENT: + if ( hook_result ) + reporter->InternalError("plugin returned non-void result for event %s", this->Name()); + + break; + + case FUNC_FLAVOR_HOOK: + if ( hook_result->GetType()->Tag() != TYPE_BOOL ) + reporter->InternalError("plugin returned non-bool for hook %s", this->Name()); + + break; + + case FUNC_FLAVOR_FUNCTION: { + const auto& yt = GetType()->Yield(); + + if ( (! yt) || yt->Tag() == TYPE_VOID ) { + if ( hook_result ) + reporter->InternalError("plugin returned non-void result for void method %s", this->Name()); + } + + else if ( hook_result && hook_result->GetType()->Tag() != yt->Tag() && yt->Tag() != TYPE_ANY ) { + reporter->InternalError("plugin returned wrong type (got %d, expecting %d) for %s", + hook_result->GetType()->Tag(), yt->Tag(), this->Name()); + } + + break; + } + } +} + +namespace detail { + +ScriptFunc::ScriptFunc(const IDPtr& arg_id) : Func(SCRIPT_FUNC) { + name = arg_id->Name(); + type = arg_id->GetType(); + frame_size = 0; +} + +ScriptFunc::ScriptFunc(std::string _name, FuncTypePtr ft, std::vector bs, std::vector priorities) { + name = std::move(_name); + frame_size = ft->ParamList()->GetTypes().size(); + type = std::move(ft); + + auto n = bs.size(); + ASSERT(n == priorities.size()); + + for ( auto i = 0u; i < n; ++i ) { + Body b; + b.stmts = std::move(bs[i]); + b.priority = priorities[i]; + bodies.push_back(b); + } + + std::stable_sort(bodies.begin(), bodies.end()); + + if ( ! bodies.empty() ) { + current_body = bodies[0].stmts; + current_priority = bodies[0].priority; + } +} + +ScriptFunc::~ScriptFunc() { + if ( captures_vec ) { + auto& cvec = *captures_vec; + auto& captures = *type->GetCaptures(); + for ( auto i = 0u; i < captures.size(); ++i ) + if ( captures[i].IsManaged() ) + ZVal::DeleteManagedType(cvec[i]); + } + + delete captures_frame; + delete captures_offset_mapping; +} + +bool ScriptFunc::IsPure() const { + return std::all_of(bodies.begin(), bodies.end(), [](const Body& b) { return b.stmts->IsPure(); }); +} + +ValPtr ScriptFunc::Invoke(zeek::Args* args, Frame* parent) const { + SegmentProfiler prof(segment_logger, location); + + if ( sample_logger ) + sample_logger->FunctionSeen(this); + + auto [handled, hook_result] = + PLUGIN_HOOK_WITH_RESULT(HOOK_CALL_FUNCTION, HookCallFunction(this, parent, args), empty_hook_result); + + CheckPluginResult(handled, hook_result, Flavor()); + + if ( handled ) + return hook_result; + + if ( bodies.empty() ) { + // Can only happen for events and hooks. + assert(Flavor() == FUNC_FLAVOR_EVENT || Flavor() == FUNC_FLAVOR_HOOK); + return Flavor() == FUNC_FLAVOR_HOOK ? val_mgr->True() : nullptr; + } + + auto f = make_intrusive(frame_size, this, args); + + // Hand down any trigger. + if ( parent ) { + f->SetTrigger({NewRef{}, parent->GetTrigger()}); + f->SetTriggerAssoc(parent->GetTriggerAssoc()); + } + + g_frame_stack.push_back(f.get()); // used for backtracing + const CallExpr* call_expr = parent ? parent->GetCall() : nullptr; + call_stack.emplace_back(CallInfo{call_expr, this, *args}); + + // If a script function is ever invoked with more arguments than it has + // parameters log an error and return. Most likely a "variadic function" + // that only has a single any parameter and is excluded from static type + // checking is involved. This should otherwise not be possible to hit. + auto num_params = static_cast(GetType()->Params()->NumFields()); + if ( args->size() > num_params ) { + emit_builtin_exception("too many arguments for function call"); + return nullptr; + } + + if ( etm && Flavor() == FUNC_FLAVOR_EVENT ) + etm->StartEvent(this, args); + + if ( g_trace_state.DoTrace() ) { + ODesc d; + DescribeDebug(&d, args); + + g_trace_state.LogTrace("%s called: %s\n", GetType()->FlavorString().c_str(), d.Description()); + } + + StmtFlowType flow = FLOW_NEXT; + ValPtr result; + + for ( const auto& body : bodies ) { + if ( body.disabled ) + continue; + + if ( sample_logger ) + sample_logger->LocationSeen(body.stmts->GetLocationInfo()); + + // Fill in the rest of the frame with the function's arguments. + for ( auto j = 0u; j < args->size(); ++j ) { + const auto& arg = (*args)[j]; + + if ( f->GetElement(j) != arg ) + // Either not yet set, or somebody reassigned the frame slot. + f->SetElement(j, arg); + } + + if ( spm ) + spm->StartInvocation(this, body.stmts); + + f->Reset(args->size()); + + try { + result = body.stmts->Exec(f.get(), flow); + } + + catch ( InterpreterException& e ) { + // Already reported, but now determine whether to unwind further. + if ( Flavor() == FUNC_FLAVOR_FUNCTION ) { + g_frame_stack.pop_back(); + call_stack.pop_back(); + // Result not set b/c exception was thrown + throw; + } + + // Continue exec'ing remaining bodies of hooks/events. + continue; + } + + if ( spm ) + spm->EndInvocation(); + + if ( f->HasDelayed() ) { + assert(! result); + assert(parent); + parent->SetDelayed(); + break; + } + + if ( Flavor() == FUNC_FLAVOR_HOOK ) { + // Ignore any return values of hook bodies, final return value + // depends on whether a body returns as a result of break statement. + result = nullptr; + + if ( flow == FLOW_BREAK ) { + // Short-circuit execution of remaining hook handler bodies. + result = val_mgr->False(); + break; + } + } + } + + call_stack.pop_back(); + + if ( Flavor() == FUNC_FLAVOR_HOOK ) { + if ( ! result ) + result = val_mgr->True(); + } + + else if ( etm && Flavor() == FUNC_FLAVOR_EVENT ) + etm->EndEvent(this, args); + + // Warn if the function returns something, but we returned from + // the function without an explicit return, or without a value. + else if ( GetType()->Yield() && GetType()->Yield()->Tag() != TYPE_VOID && + (flow != FLOW_RETURN /* we fell off the end */ || ! result /* explicit return with no result */) && + ! f->HasDelayed() ) + reporter->Warning("non-void function returning without a value: %s", Name()); + + if ( result && g_trace_state.DoTrace() ) { + ODesc d; + result->Describe(&d); + + g_trace_state.LogTrace("Function return: %s\n", d.Description()); + } + + g_frame_stack.pop_back(); + + return result; +} + +void ScriptFunc::CreateCaptures(Frame* f) { + const auto& captures = type->GetCaptures(); + + if ( ! captures ) + return; + + // Create *either* a private Frame to hold the values of captured + // variables, and a mapping from those variables to their offsets + // in the Frame; *or* a ZVal frame if this script has a ZAM-compiled + // body. + ASSERT(bodies.size() == 1); + + if ( bodies[0].stmts->Tag() == STMT_ZAM ) + captures_vec = std::make_unique>(); + else { + delete captures_frame; + delete captures_offset_mapping; + captures_frame = new Frame(captures->size(), this, nullptr); + captures_offset_mapping = new OffsetMap; + } + + int offset = 0; + for ( const auto& c : *captures ) { + auto v = f->GetElementByID(c.Id()); + + if ( v ) { + if ( c.IsDeepCopy() || ! v->Modifiable() ) + v = v->Clone(); + + if ( captures_vec ) + // Don't use v->GetType() here, as that might + // be "any" and we need to convert. + captures_vec->push_back(ZVal(v, c.Id()->GetType())); + else + captures_frame->SetElement(offset, std::move(v)); + } + + else if ( captures_vec ) + captures_vec->push_back(ZVal()); + + if ( ! captures_vec ) + captures_offset_mapping->insert_or_assign(c.Id()->Name(), offset); + + ++offset; + } +} + +void ScriptFunc::CreateCaptures(std::unique_ptr> cvec) { + const auto& captures = *type->GetCaptures(); + + ASSERT(cvec->size() == captures.size()); + ASSERT(bodies.size() == 1 && bodies[0].stmts->Tag() == STMT_ZAM); + + captures_vec = std::move(cvec); + + auto n = captures.size(); + for ( auto i = 0U; i < n; ++i ) { + auto& c_i = captures[i]; + auto& cv_i = (*captures_vec)[i]; + + if ( c_i.IsDeepCopy() ) { + auto& t = c_i.Id()->GetType(); + auto new_cv_i = cv_i.ToVal(t)->Clone(); + if ( c_i.IsManaged() ) + ZVal::DeleteManagedType(cv_i); + + cv_i = ZVal(std::move(new_cv_i), t); + } + } +} + +void ScriptFunc::SetCaptures(Frame* f) { + const auto& captures = type->GetCaptures(); + ASSERT(captures); + + delete captures_frame; + delete captures_offset_mapping; + captures_frame = f; + captures_offset_mapping = new OffsetMap; + + int offset = 0; + for ( const auto& c : *captures ) { + captures_offset_mapping->insert_or_assign(c.Id()->Name(), offset); + ++offset; + } +} + +void ScriptFunc::AddBody(StmtPtr new_body, const std::vector& new_inits, size_t new_frame_size, int priority, + const std::set& groups) { + if ( new_frame_size > frame_size ) + frame_size = new_frame_size; + + auto num_args = static_cast(GetType()->Params()->NumFields()); + + if ( num_args > frame_size ) + frame_size = num_args; + + new_body = AddInits(std::move(new_body), new_inits); + + if ( Flavor() == FUNC_FLAVOR_FUNCTION ) { + // For functions, we replace the old body with the new one. + assert(bodies.size() <= 1); + bodies.clear(); + } + + Body b; + b.stmts = new_body; + b.groups = groups; + current_body = new_body; + current_priority = b.priority = priority; + + bodies.push_back(b); + std::stable_sort(bodies.begin(), bodies.end()); +} + +void ScriptFunc::ReplaceBody(const StmtPtr& old_body, StmtPtr new_body) { + bool found_it = false; + + for ( auto body = bodies.begin(); body != bodies.end(); ++body ) + if ( body->stmts.get() == old_body.get() ) { + if ( new_body ) { + body->stmts = new_body; + current_priority = body->priority; + } + else + bodies.erase(body); + + found_it = true; + break; + } + + current_body = new_body; +} + +bool ScriptFunc::DeserializeCaptures(const broker::vector& data) { + auto result = Frame::Unserialize(data); + + ASSERT(result.first); + + auto& f = result.second; + + if ( bodies[0].stmts->Tag() == STMT_ZAM ) { + auto& captures = *type->GetCaptures(); + int n = f->FrameSize(); + + ASSERT(captures.size() == static_cast(n)); + + auto cvec = std::make_unique>(); + + for ( int i = 0; i < n; ++i ) { + auto& f_i = f->GetElement(i); + cvec->push_back(ZVal(f_i, captures[i].Id()->GetType())); + } + + CreateCaptures(std::move(cvec)); + } + + else + SetCaptures(f.release()); + + return true; +} + +FuncPtr ScriptFunc::DoClone() { + // ScriptFunc could hold a closure. In this case a clone of it must + // store a copy of this closure. + // + // We don't use make_intrusive<> directly because we're accessing + // a protected constructor. + auto other = IntrusivePtr{AdoptRef{}, new ScriptFunc()}; + + CopyStateInto(other.get()); + + other->frame_size = frame_size; + other->outer_ids = outer_ids; + + if ( captures_frame ) { + other->captures_frame = captures_frame->Clone(); + other->captures_offset_mapping = new OffsetMap; + *other->captures_offset_mapping = *captures_offset_mapping; + } + + if ( captures_vec ) { + auto cv_i = captures_vec->begin(); + other->captures_vec = std::make_unique>(); + for ( auto& c : *type->GetCaptures() ) { + // Need to clone cv_i. + auto& t_i = c.Id()->GetType(); + auto cv_i_val = cv_i->ToVal(t_i)->Clone(); + other->captures_vec->push_back(ZVal(std::move(cv_i_val), t_i)); + ++cv_i; + } + } + + return other; +} + +broker::expected ScriptFunc::SerializeCaptures() const { + if ( captures_vec ) { + auto& cv = *captures_vec; + auto& captures = *type->GetCaptures(); + int n = captures_vec->size(); + auto temp_frame = make_intrusive(n, this, nullptr); + + for ( int i = 0; i < n; ++i ) { + auto c_i = cv[i].ToVal(captures[i].Id()->GetType()); + temp_frame->SetElement(i, c_i); + } + + return temp_frame->Serialize(); + } + + if ( captures_frame ) + return captures_frame->Serialize(); + + // No captures, return an empty vector. + return broker::vector{}; +} + +void ScriptFunc::Describe(ODesc* d) const { + d->Add(Name()); + + d->NL(); + d->AddCount(frame_size); + for ( const auto& body : bodies ) { + body.stmts->AccessStats(d); + body.stmts->Describe(d); + } +} + +StmtPtr ScriptFunc::AddInits(StmtPtr body, const std::vector& inits) { + if ( inits.empty() ) + return body; + + auto stmt_series = make_intrusive(); + stmt_series->Stmts().push_back(make_intrusive(inits)); + stmt_series->Stmts().push_back(std::move(body)); + + return stmt_series; +} + +BuiltinFunc::BuiltinFunc(built_in_func arg_func, const char* arg_name, bool arg_is_pure) : Func(BUILTIN_FUNC) { + func = arg_func; + name = make_full_var_name(GLOBAL_MODULE_NAME, arg_name); + is_pure = arg_is_pure; + + const auto& id = lookup_ID(Name(), GLOBAL_MODULE_NAME, false); + if ( ! id ) + reporter->InternalError("built-in function %s missing", Name()); + if ( id->HasVal() ) + reporter->InternalError("built-in function %s multiply defined", Name()); + + type = id->GetType(); + id->SetVal(make_intrusive(IntrusivePtr{NewRef{}, this})); + id->SetConst(); +} + +bool BuiltinFunc::IsPure() const { return is_pure; } + +ValPtr BuiltinFunc::Invoke(Args* args, Frame* parent) const { + if ( spm ) + spm->StartInvocation(this); + + SegmentProfiler prof(segment_logger, Name()); + + if ( sample_logger ) + sample_logger->FunctionSeen(this); + + auto [handled, hook_result] = + PLUGIN_HOOK_WITH_RESULT(HOOK_CALL_FUNCTION, HookCallFunction(this, parent, args), empty_hook_result); + + CheckPluginResult(handled, hook_result, FUNC_FLAVOR_FUNCTION); + + if ( handled ) { + if ( spm ) + spm->EndInvocation(); + return hook_result; + } + + if ( g_trace_state.DoTrace() ) { + ODesc d; + DescribeDebug(&d, args); + + g_trace_state.LogTrace("\tBuiltin Function called: %s\n", d.Description()); + } + + const CallExpr* call_expr = parent ? parent->GetCall() : nullptr; + call_stack.emplace_back(CallInfo{call_expr, this, *args}); + auto result = std::move(func(parent, args).rval); + call_stack.pop_back(); + + if ( result && g_trace_state.DoTrace() ) { + ODesc d; + result->Describe(&d); + + g_trace_state.LogTrace("\tFunction return: %s\n", d.Description()); + } + + if ( spm ) + spm->EndInvocation(); + + return result; +} + +void BuiltinFunc::Describe(ODesc* d) const { + d->Add(Name()); + d->AddCount(is_pure); +} + +bool check_built_in_call(BuiltinFunc* f, CallExpr* call) { + if ( f->TheFunc() != BifFunc::fmt_bif ) + return true; + + const ExprPList& args = call->Args()->Exprs(); + if ( args.length() == 0 ) { + // Empty calls are allowed, since you can't just + // use "print;" to get a blank line. + return true; + } + + const Expr* fmt_str_arg = args[0]; + if ( fmt_str_arg->GetType()->Tag() != TYPE_STRING ) { + call->Error("first argument to util::fmt() needs to be a format string"); + return false; + } + + auto fmt_str_val = fmt_str_arg->Eval(nullptr); + + if ( fmt_str_val ) { + const char* fmt_str = fmt_str_val->AsStringVal()->CheckString(); + + int num_fmt = 0; + while ( *fmt_str ) { + if ( *(fmt_str++) != '%' ) + continue; + + if ( ! *fmt_str ) { + call->Error("format string ends with bare '%'"); + return false; + } + + if ( *(fmt_str++) != '%' ) + // Not a "%%" escape. + ++num_fmt; + } + + if ( args.length() != num_fmt + 1 ) { + call->Error("mismatch between format string to util::fmt() and number of arguments passed"); + return false; + } + } + + return true; +} // Gets a function's priority from its Scope's attributes. Errors if it sees any // problems. -static int get_func_priority(const std::vector& attrs) - { - int priority = 0; +static int get_func_priority(const std::vector& attrs) { + int priority = 0; - for ( const auto& a : attrs ) - { - if ( a->Tag() == ATTR_DEPRECATED || a->Tag() == ATTR_IS_USED || a->Tag() == ATTR_GROUP ) - continue; + for ( const auto& a : attrs ) { + if ( a->Tag() == ATTR_DEPRECATED || a->Tag() == ATTR_IS_USED || a->Tag() == ATTR_GROUP ) + continue; - if ( a->Tag() != ATTR_PRIORITY ) - { - a->Error("illegal attribute for function body"); - continue; - } + if ( a->Tag() != ATTR_PRIORITY ) { + a->Error("illegal attribute for function body"); + continue; + } - auto v = a->GetExpr()->Eval(nullptr); + auto v = a->GetExpr()->Eval(nullptr); - if ( ! v ) - { - a->Error("cannot evaluate attribute expression"); - continue; - } + if ( ! v ) { + a->Error("cannot evaluate attribute expression"); + continue; + } - if ( ! IsIntegral(v->GetType()->Tag()) ) - { - a->Error("expression is not of integral type"); - continue; - } + if ( ! IsIntegral(v->GetType()->Tag()) ) { + a->Error("expression is not of integral type"); + continue; + } - priority = v->InternalInt(); - } + priority = v->InternalInt(); + } - return priority; - } + return priority; +} // Get a function's groups from its Scope's attributes. Errors if it sees any // problems with the group tag. get_func_priority() checks for illegal // attributes, so we don't do this here. -static std::set get_func_groups(const std::vector& attrs) - { - std::set groups; +static std::set get_func_groups(const std::vector& attrs) { + std::set groups; - for ( const auto& a : attrs ) - { - if ( a->Tag() != ATTR_GROUP ) - continue; + for ( const auto& a : attrs ) { + if ( a->Tag() != ATTR_GROUP ) + continue; - auto v = a->GetExpr()->Eval(nullptr); + auto v = a->GetExpr()->Eval(nullptr); - if ( ! v ) - { - a->Error("cannot evaluate attribute expression"); - continue; - } + if ( ! v ) { + a->Error("cannot evaluate attribute expression"); + continue; + } - if ( ! IsString(v->GetType()->Tag()) ) - { - a->Error("expression is not of string type"); - continue; - } + if ( ! IsString(v->GetType()->Tag()) ) { + a->Error("expression is not of string type"); + continue; + } - auto group = event_registry->RegisterGroup(EventGroupKind::Attribute, - v->AsStringVal()->ToStdStringView()); - groups.insert(group); - } + auto group = event_registry->RegisterGroup(EventGroupKind::Attribute, v->AsStringVal()->ToStdStringView()); + groups.insert(group); + } - return groups; - } + return groups; +} -FunctionIngredients::FunctionIngredients(ScopePtr _scope, StmtPtr _body, - const std::string& module_name) - { - scope = std::move(_scope); - body = std::move(_body); +FunctionIngredients::FunctionIngredients(ScopePtr _scope, StmtPtr _body, const std::string& module_name) { + scope = std::move(_scope); + body = std::move(_body); - frame_size = scope->Length(); - inits = scope->GetInits(); + frame_size = scope->Length(); + inits = scope->GetInits(); - id = scope->GetID(); + id = scope->GetID(); - const auto& attrs = scope->Attrs(); + const auto& attrs = scope->Attrs(); - if ( attrs ) - { - priority = get_func_priority(*attrs); + if ( attrs ) { + priority = get_func_priority(*attrs); - groups = get_func_groups(*attrs); + groups = get_func_groups(*attrs); - for ( const auto& a : *attrs ) - if ( a->Tag() == ATTR_IS_USED ) - { - // Associate this with the identifier, too. - id->AddAttr(make_intrusive(ATTR_IS_USED)); - break; - } - } - else - priority = 0; + for ( const auto& a : *attrs ) + if ( a->Tag() == ATTR_IS_USED ) { + // Associate this with the identifier, too. + id->AddAttr(make_intrusive(ATTR_IS_USED)); + break; + } + } + else + priority = 0; - // Implicit module event groups for events and hooks. - auto flavor = id->GetType()->Flavor(); - if ( flavor == FUNC_FLAVOR_EVENT || flavor == FUNC_FLAVOR_HOOK ) - { - auto module_group = event_registry->RegisterGroup(EventGroupKind::Module, module_name); - groups.insert(module_group); - } - } + // Implicit module event groups for events and hooks. + auto flavor = id->GetType()->Flavor(); + if ( flavor == FUNC_FLAVOR_EVENT || flavor == FUNC_FLAVOR_HOOK ) { + auto module_group = event_registry->RegisterGroup(EventGroupKind::Module, module_name); + groups.insert(module_group); + } +} zeek::RecordValPtr make_backtrace_element(std::string_view name, const VectorValPtr args, - const zeek::detail::Location* loc) - { - static auto elem_type = id::find_type("BacktraceElement"); - static auto function_name_idx = elem_type->FieldOffset("function_name"); - static auto function_args_idx = elem_type->FieldOffset("function_args"); - static auto file_location_idx = elem_type->FieldOffset("file_location"); - static auto line_location_idx = elem_type->FieldOffset("line_location"); + const zeek::detail::Location* loc) { + static auto elem_type = id::find_type("BacktraceElement"); + static auto function_name_idx = elem_type->FieldOffset("function_name"); + static auto function_args_idx = elem_type->FieldOffset("function_args"); + static auto file_location_idx = elem_type->FieldOffset("file_location"); + static auto line_location_idx = elem_type->FieldOffset("line_location"); - auto elem = make_intrusive(elem_type); - elem->Assign(function_name_idx, name.data()); - elem->Assign(function_args_idx, std::move(args)); + auto elem = make_intrusive(elem_type); + elem->Assign(function_name_idx, name.data()); + elem->Assign(function_args_idx, std::move(args)); - if ( loc ) - { - elem->Assign(file_location_idx, loc->filename); - elem->Assign(line_location_idx, loc->first_line); - } + if ( loc ) { + elem->Assign(file_location_idx, loc->filename); + elem->Assign(line_location_idx, loc->first_line); + } - return elem; - } + return elem; +} -zeek::VectorValPtr get_current_script_backtrace() - { - static auto backtrace_type = id::find_type("Backtrace"); +zeek::VectorValPtr get_current_script_backtrace() { + static auto backtrace_type = id::find_type("Backtrace"); - auto rval = make_intrusive(backtrace_type); + auto rval = make_intrusive(backtrace_type); - // The body of the following loop can wind up adding items to - // the call stack (because MakeCallArgumentVector() evaluates - // default arguments, which can in turn involve calls to script - // functions), so we work from a copy of the current call stack - // to prevent problems with iterator invalidation. - auto cs_copy = zeek::detail::call_stack; + // The body of the following loop can wind up adding items to + // the call stack (because MakeCallArgumentVector() evaluates + // default arguments, which can in turn involve calls to script + // functions), so we work from a copy of the current call stack + // to prevent problems with iterator invalidation. + auto cs_copy = zeek::detail::call_stack; - for ( auto it = cs_copy.rbegin(); it != cs_copy.rend(); ++it ) - { - const auto& ci = *it; - if ( ! ci.func ) - // This happens for compiled code. - continue; + for ( auto it = cs_copy.rbegin(); it != cs_copy.rend(); ++it ) { + const auto& ci = *it; + if ( ! ci.func ) + // This happens for compiled code. + continue; - const auto& params = ci.func->GetType()->Params(); - auto args = MakeCallArgumentVector(ci.args, params); + const auto& params = ci.func->GetType()->Params(); + auto args = MakeCallArgumentVector(ci.args, params); - auto elem = make_backtrace_element(ci.func->Name(), std::move(args), - ci.call ? ci.call->GetLocationInfo() : nullptr); - rval->Append(std::move(elem)); - } + auto elem = + make_backtrace_element(ci.func->Name(), std::move(args), ci.call ? ci.call->GetLocationInfo() : nullptr); + rval->Append(std::move(elem)); + } - return rval; - } + return rval; +} -static void emit_builtin_error_common(const char* msg, Obj* arg, bool unwind) - { - auto emit = [=](const CallExpr* ce) - { - if ( ce ) - { - if ( unwind ) - { - if ( arg ) - { - ODesc d; - arg->Describe(&d); - reporter->ExprRuntimeError(ce, "%s (%s), during call:", msg, d.Description()); - } - else - reporter->ExprRuntimeError(ce, "%s", msg); - } - else - ce->Error(msg, arg); - } - else - { - if ( arg ) - { - if ( unwind ) - reporter->RuntimeError(arg->GetLocationInfo(), "%s", msg); - else - arg->Error(msg); - } - else - { - if ( unwind ) - reporter->RuntimeError(nullptr, "%s", msg); - else - reporter->Error("%s", msg); - } - } - }; +static void emit_builtin_error_common(const char* msg, Obj* arg, bool unwind) { + auto emit = [=](const CallExpr* ce) { + if ( ce ) { + if ( unwind ) { + if ( arg ) { + ODesc d; + arg->Describe(&d); + reporter->ExprRuntimeError(ce, "%s (%s), during call:", msg, d.Description()); + } + else + reporter->ExprRuntimeError(ce, "%s", msg); + } + else + ce->Error(msg, arg); + } + else { + if ( arg ) { + if ( unwind ) + reporter->RuntimeError(arg->GetLocationInfo(), "%s", msg); + else + arg->Error(msg); + } + else { + if ( unwind ) + reporter->RuntimeError(nullptr, "%s", msg); + else + reporter->Error("%s", msg); + } + } + }; - if ( call_stack.empty() ) - { - // Shouldn't happen unless someone (mistakenly) calls builtin_error() - // from somewhere that's not even evaluating script-code. - emit(nullptr); - return; - } + if ( call_stack.empty() ) { + // Shouldn't happen unless someone (mistakenly) calls builtin_error() + // from somewhere that's not even evaluating script-code. + emit(nullptr); + return; + } - auto last_call = call_stack.back(); + auto last_call = call_stack.back(); - if ( call_stack.size() < 2 ) - { - // Don't need to check for wrapper function like "::__" - emit(last_call.call); - return; - } + if ( call_stack.size() < 2 ) { + // Don't need to check for wrapper function like "::__" + emit(last_call.call); + return; + } - auto starts_with_double_underscore = [](const std::string& name) -> bool - { - return name.size() > 2 && name[0] == '_' && name[1] == '_'; - }; - std::string last_func = last_call.func->Name(); + auto starts_with_double_underscore = [](const std::string& name) -> bool { + return name.size() > 2 && name[0] == '_' && name[1] == '_'; + }; + std::string last_func = last_call.func->Name(); - auto pos = last_func.find_first_of("::"); - std::string wrapper_func; + auto pos = last_func.find_first_of("::"); + std::string wrapper_func; - if ( pos == std::string::npos ) - { - if ( ! starts_with_double_underscore(last_func) ) - { - emit(last_call.call); - return; - } + if ( pos == std::string::npos ) { + if ( ! starts_with_double_underscore(last_func) ) { + emit(last_call.call); + return; + } - wrapper_func = last_func.substr(2); - } - else - { - auto module_name = last_func.substr(0, pos); - auto func_name = last_func.substr(pos + 2); + wrapper_func = last_func.substr(2); + } + else { + auto module_name = last_func.substr(0, pos); + auto func_name = last_func.substr(pos + 2); - if ( ! starts_with_double_underscore(func_name) ) - { - emit(last_call.call); - return; - } + if ( ! starts_with_double_underscore(func_name) ) { + emit(last_call.call); + return; + } - wrapper_func = module_name + "::" + func_name.substr(2); - } + wrapper_func = module_name + "::" + func_name.substr(2); + } - auto parent_call = call_stack[call_stack.size() - 2]; - auto parent_func = parent_call.func->Name(); + auto parent_call = call_stack[call_stack.size() - 2]; + auto parent_func = parent_call.func->Name(); - if ( wrapper_func == parent_func ) - emit(parent_call.call); - else - emit(last_call.call); - } + if ( wrapper_func == parent_func ) + emit(parent_call.call); + else + emit(last_call.call); +} -void emit_builtin_exception(const char* msg) - { - emit_builtin_error_common(msg, nullptr, true); - } +void emit_builtin_exception(const char* msg) { emit_builtin_error_common(msg, nullptr, true); } -void emit_builtin_exception(const char* msg, const ValPtr& arg) - { - emit_builtin_error_common(msg, arg.get(), true); - } +void emit_builtin_exception(const char* msg, const ValPtr& arg) { emit_builtin_error_common(msg, arg.get(), true); } -void emit_builtin_exception(const char* msg, Obj* arg) - { - emit_builtin_error_common(msg, arg, true); - } +void emit_builtin_exception(const char* msg, Obj* arg) { emit_builtin_error_common(msg, arg, true); } -void init_primary_bifs() - { - if ( did_builtin_init ) - return; +void init_primary_bifs() { + if ( did_builtin_init ) + return; - ProcStats = id::find_type("ProcStats"); - NetStats = id::find_type("NetStats"); - MatcherStats = id::find_type("MatcherStats"); - ConnStats = id::find_type("ConnStats"); - ReassemblerStats = id::find_type("ReassemblerStats"); - DNSStats = id::find_type("DNSStats"); - GapStats = id::find_type("GapStats"); - EventStats = id::find_type("EventStats"); - TimerStats = id::find_type("TimerStats"); - FileAnalysisStats = id::find_type("FileAnalysisStats"); - ThreadStats = id::find_type("ThreadStats"); - BrokerStats = id::find_type("BrokerStats"); - ReporterStats = id::find_type("ReporterStats"); + ProcStats = id::find_type("ProcStats"); + NetStats = id::find_type("NetStats"); + MatcherStats = id::find_type("MatcherStats"); + ConnStats = id::find_type("ConnStats"); + ReassemblerStats = id::find_type("ReassemblerStats"); + DNSStats = id::find_type("DNSStats"); + GapStats = id::find_type("GapStats"); + EventStats = id::find_type("EventStats"); + TimerStats = id::find_type("TimerStats"); + FileAnalysisStats = id::find_type("FileAnalysisStats"); + ThreadStats = id::find_type("ThreadStats"); + BrokerStats = id::find_type("BrokerStats"); + ReporterStats = id::find_type("ReporterStats"); - var_sizes = id::find_type("var_sizes")->AsTableType(); + var_sizes = id::find_type("var_sizes")->AsTableType(); #include "CPP-load.bif.func_init" #include "communityid.bif.func_init" @@ -1222,25 +1070,18 @@ void init_primary_bifs() #include "supervisor.bif.func_init" #include "zeek.bif.func_init" - init_builtin_types(); - did_builtin_init = true; - } + init_builtin_types(); + did_builtin_init = true; +} - } // namespace detail +} // namespace detail -void emit_builtin_error(const char* msg) - { - zeek::detail::emit_builtin_error_common(msg, nullptr, false); - } +void emit_builtin_error(const char* msg) { zeek::detail::emit_builtin_error_common(msg, nullptr, false); } -void emit_builtin_error(const char* msg, const zeek::ValPtr& arg) - { - zeek::detail::emit_builtin_error_common(msg, arg.get(), false); - } +void emit_builtin_error(const char* msg, const zeek::ValPtr& arg) { + zeek::detail::emit_builtin_error_common(msg, arg.get(), false); +} -void emit_builtin_error(const char* msg, Obj* arg) - { - zeek::detail::emit_builtin_error_common(msg, arg, false); - } +void emit_builtin_error(const char* msg, Obj* arg) { zeek::detail::emit_builtin_error_common(msg, arg, false); } - } // namespace zeek +} // namespace zeek diff --git a/src/Func.h b/src/Func.h index ed0e59aea8..331ce5c61f 100644 --- a/src/Func.h +++ b/src/Func.h @@ -18,21 +18,19 @@ #include "zeek/ZeekArgs.h" #include "zeek/ZeekList.h" -namespace broker - { +namespace broker { class data; using vector = std::vector; -template class expected; - } +template +class expected; +} // namespace broker -namespace zeek - { +namespace zeek { class Val; class FuncType; -namespace detail - { +namespace detail { class Scope; class Stmt; @@ -46,7 +44,7 @@ using StmtPtr = IntrusivePtr; class ScriptFunc; class FunctionIngredients; - } // namespace detail +} // namespace detail class EventGroup; using EventGroupPtr = std::shared_ptr; @@ -54,350 +52,329 @@ using EventGroupPtr = std::shared_ptr; class Func; using FuncPtr = IntrusivePtr; -class Func : public Obj - { +class Func : public Obj { public: - static inline const FuncPtr nil; + static inline const FuncPtr nil; - enum Kind - { - SCRIPT_FUNC, - BUILTIN_FUNC - }; + enum Kind { SCRIPT_FUNC, BUILTIN_FUNC }; - explicit Func(Kind arg_kind) : kind(arg_kind) { } + explicit Func(Kind arg_kind) : kind(arg_kind) {} - virtual bool IsPure() const = 0; - FunctionFlavor Flavor() const { return GetType()->Flavor(); } + virtual bool IsPure() const = 0; + FunctionFlavor Flavor() const { return GetType()->Flavor(); } - struct Body - { - detail::StmtPtr stmts; - int priority; - std::set groups; - // If any of the groups are disabled, this body is disabled. - // The disabled field is updated from EventGroup instances. - bool disabled = false; + struct Body { + detail::StmtPtr stmts; + int priority; + std::set groups; + // If any of the groups are disabled, this body is disabled. + // The disabled field is updated from EventGroup instances. + bool disabled = false; - bool operator<(const Body& other) const - { - return priority > other.priority; - } // reverse sort - }; + bool operator<(const Body& other) const { return priority > other.priority; } // reverse sort + }; - const std::vector& GetBodies() const { return bodies; } - bool HasBodies() const { return ! bodies.empty(); } + const std::vector& GetBodies() const { return bodies; } + bool HasBodies() const { return ! bodies.empty(); } - /** - * Are there bodies and is any one of them enabled? - * - * @return true if bodies exist and at least one is enabled. - */ - bool HasEnabledBodies() const { return ! bodies.empty() && has_enabled_bodies; }; + /** + * Are there bodies and is any one of them enabled? + * + * @return true if bodies exist and at least one is enabled. + */ + bool HasEnabledBodies() const { return ! bodies.empty() && has_enabled_bodies; }; - /** - * Calls a Zeek function. - * @param args the list of arguments to the function call. - * @param parent the frame from which the function is being called. - * @return the return value of the function call. - */ - virtual ValPtr Invoke(zeek::Args* args, detail::Frame* parent = nullptr) const = 0; + /** + * Calls a Zeek function. + * @param args the list of arguments to the function call. + * @param parent the frame from which the function is being called. + * @return the return value of the function call. + */ + virtual ValPtr Invoke(zeek::Args* args, detail::Frame* parent = nullptr) const = 0; - /** - * A version of Invoke() taking a variable number of individual arguments. - */ - template - std::enable_if_t>, ValPtr>, - ValPtr> - Invoke(Args&&... args) const - { - auto zargs = zeek::Args{std::forward(args)...}; - return Invoke(&zargs); - } + /** + * A version of Invoke() taking a variable number of individual arguments. + */ + template + std::enable_if_t>, ValPtr>, ValPtr> Invoke( + Args&&... args) const { + auto zargs = zeek::Args{std::forward(args)...}; + return Invoke(&zargs); + } - // Various ways to add a new event handler to an existing function - // (event). The usual version to use is the first with its default - // parameter. All of the others are for use by script optimization, - // as is a non-default second parameter to the first method, which - // overrides the function body in "ingr". - void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr); - virtual void AddBody(detail::StmtPtr new_body, const std::vector& new_inits, - size_t new_frame_size, int priority, - const std::set& groups); - void AddBody(detail::StmtPtr new_body, const std::vector& new_inits, - size_t new_frame_size, int priority = 0); - void AddBody(detail::StmtPtr new_body, size_t new_frame_size); + // Various ways to add a new event handler to an existing function + // (event). The usual version to use is the first with its default + // parameter. All of the others are for use by script optimization, + // as is a non-default second parameter to the first method, which + // overrides the function body in "ingr". + void AddBody(const detail::FunctionIngredients& ingr, detail::StmtPtr new_body = nullptr); + virtual void AddBody(detail::StmtPtr new_body, const std::vector& new_inits, size_t new_frame_size, + int priority, const std::set& groups); + void AddBody(detail::StmtPtr new_body, const std::vector& new_inits, size_t new_frame_size, + int priority = 0); + void AddBody(detail::StmtPtr new_body, size_t new_frame_size); - virtual void SetScope(detail::ScopePtr newscope); - virtual detail::ScopePtr GetScope() const { return scope; } + virtual void SetScope(detail::ScopePtr newscope); + virtual detail::ScopePtr GetScope() const { return scope; } - const FuncTypePtr& GetType() const { return type; } + const FuncTypePtr& GetType() const { return type; } - Kind GetKind() const { return kind; } + Kind GetKind() const { return kind; } - const char* Name() const { return name.c_str(); } - void SetName(const char* arg_name) { name = arg_name; } + const char* Name() const { return name.c_str(); } + void SetName(const char* arg_name) { name = arg_name; } - void Describe(ODesc* d) const override = 0; - virtual void DescribeDebug(ODesc* d, const zeek::Args* args) const; + void Describe(ODesc* d) const override = 0; + virtual void DescribeDebug(ODesc* d, const zeek::Args* args) const; - virtual FuncPtr DoClone(); + virtual FuncPtr DoClone(); - virtual detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; + virtual detail::TraversalCode Traverse(detail::TraversalCallback* cb) const; protected: - Func() = default; + Func() = default; - // Copies this function's state into other. - void CopyStateInto(Func* other) const; + // Copies this function's state into other. + void CopyStateInto(Func* other) const; - // Helper function for checking result of plugin hook. - void CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFlavor flavor) const; + // Helper function for checking result of plugin hook. + void CheckPluginResult(bool handled, const ValPtr& hook_result, FunctionFlavor flavor) const; - std::vector bodies; - detail::ScopePtr scope; - Kind kind = SCRIPT_FUNC; - FuncTypePtr type; - std::string name; + std::vector bodies; + detail::ScopePtr scope; + Kind kind = SCRIPT_FUNC; + FuncTypePtr type; + std::string name; private: - // EventGroup updates Func::Body.disabled and has_enabled_bodies. - // This is friend/private with EventGroup here so that we do not - // expose accessors in the zeek:: public interface. - friend class EventGroup; - bool has_enabled_bodies = true; - }; + // EventGroup updates Func::Body.disabled and has_enabled_bodies. + // This is friend/private with EventGroup here so that we do not + // expose accessors in the zeek:: public interface. + friend class EventGroup; + bool has_enabled_bodies = true; +}; -namespace detail - { +namespace detail { -class ScriptFunc : public Func - { +class ScriptFunc : public Func { public: - ScriptFunc(const IDPtr& id); + ScriptFunc(const IDPtr& id); - // For compiled scripts. - ScriptFunc(std::string name, FuncTypePtr ft, std::vector bodies, - std::vector priorities); + // For compiled scripts. + ScriptFunc(std::string name, FuncTypePtr ft, std::vector bodies, std::vector priorities); - ~ScriptFunc() override; + ~ScriptFunc() override; - bool IsPure() const override; - ValPtr Invoke(zeek::Args* args, Frame* parent) const override; + bool IsPure() const override; + ValPtr Invoke(zeek::Args* args, Frame* parent) const override; - /** - * Creates a separate frame for captures and initializes its - * elements. The list of captures comes from the ScriptFunc's - * type, so doesn't need to be passed in, just the frame to - * use in evaluating the identifiers. - * - * @param f the frame used for evaluating the captured identifiers - */ - void CreateCaptures(Frame* f); + /** + * Creates a separate frame for captures and initializes its + * elements. The list of captures comes from the ScriptFunc's + * type, so doesn't need to be passed in, just the frame to + * use in evaluating the identifiers. + * + * @param f the frame used for evaluating the captured identifiers + */ + void CreateCaptures(Frame* f); - /** - * Uses the given set of ZVal's for captures. Note that this is - * different from the method above, which uses its argument to - * compute the captures, rather than here where they are pre-computed. - * - * Makes deep copies if required. - * - * @param cvec a vector of ZVal's corresponding to the captures. - */ - void CreateCaptures(std::unique_ptr> cvec); + /** + * Uses the given set of ZVal's for captures. Note that this is + * different from the method above, which uses its argument to + * compute the captures, rather than here where they are pre-computed. + * + * Makes deep copies if required. + * + * @param cvec a vector of ZVal's corresponding to the captures. + */ + void CreateCaptures(std::unique_ptr> cvec); - /** - * Returns the frame associated with this function for tracking - * captures, or nil if there isn't one. - * - * @return internal frame kept by the function for persisting captures - */ - Frame* GetCapturesFrame() const { return captures_frame; } + /** + * Returns the frame associated with this function for tracking + * captures, or nil if there isn't one. + * + * @return internal frame kept by the function for persisting captures + */ + Frame* GetCapturesFrame() const { return captures_frame; } - /** - * Returns the set of ZVal's used for captures. It's okay to modify - * these as long as memory-management is done for managed entries. - * - * @return internal vector of ZVal's kept for persisting captures - */ - auto& GetCapturesVec() const - { - ASSERT(captures_vec); - return *captures_vec; - } + /** + * Returns the set of ZVal's used for captures. It's okay to modify + * these as long as memory-management is done for managed entries. + * + * @return internal vector of ZVal's kept for persisting captures + */ + auto& GetCapturesVec() const { + ASSERT(captures_vec); + return *captures_vec; + } - // Same definition as in Frame.h. - using OffsetMap = std::unordered_map; + // Same definition as in Frame.h. + using OffsetMap = std::unordered_map; - /** - * Returns the mapping of captures to slots in the captures frame. - * - * @return pointer to mapping of captures to slots - */ - const OffsetMap* GetCapturesOffsetMap() const { return captures_offset_mapping; } + /** + * Returns the mapping of captures to slots in the captures frame. + * + * @return pointer to mapping of captures to slots + */ + const OffsetMap* GetCapturesOffsetMap() const { return captures_offset_mapping; } - /** - * Serializes this function's capture frame. - * - * @return a serialized version of the function's capture frame. - */ - virtual broker::expected SerializeCaptures() const; + /** + * Serializes this function's capture frame. + * + * @return a serialized version of the function's capture frame. + */ + virtual broker::expected SerializeCaptures() const; - /** - * Sets the captures frame to one built from *data*. - * - * @param data a serialized frame - */ - bool DeserializeCaptures(const broker::vector& data); + /** + * Sets the captures frame to one built from *data*. + * + * @param data a serialized frame + */ + bool DeserializeCaptures(const broker::vector& data); - using Func::AddBody; + using Func::AddBody; - void AddBody(detail::StmtPtr new_body, const std::vector& new_inits, - size_t new_frame_size, int priority, - const std::set& groups) override; + void AddBody(detail::StmtPtr new_body, const std::vector& new_inits, size_t new_frame_size, + int priority, const std::set& groups) override; - /** - * Replaces the given current instance of a function body with - * a new one. If new_body is nil then the current instance is - * deleted with no replacement. - * - * @param old_body Body to replace. - * @param new_body New body to use; can be nil. - */ - void ReplaceBody(const detail::StmtPtr& old_body, detail::StmtPtr new_body); + /** + * Replaces the given current instance of a function body with + * a new one. If new_body is nil then the current instance is + * deleted with no replacement. + * + * @param old_body Body to replace. + * @param new_body New body to use; can be nil. + */ + void ReplaceBody(const detail::StmtPtr& old_body, detail::StmtPtr new_body); - StmtPtr CurrentBody() const { return current_body; } - int CurrentPriority() const { return current_priority; } + StmtPtr CurrentBody() const { return current_body; } + int CurrentPriority() const { return current_priority; } - /** - * Returns the function's frame size. - * @return The number of ValPtr slots in the function's frame. - */ - int FrameSize() const { return frame_size; } + /** + * Returns the function's frame size. + * @return The number of ValPtr slots in the function's frame. + */ + int FrameSize() const { return frame_size; } - /** - * Changes the function's frame size to a new size - used for - * script optimization/compilation. - * - * @param new_size The frame size the function should use. - */ - void SetFrameSize(int new_size) { frame_size = new_size; } + /** + * Changes the function's frame size to a new size - used for + * script optimization/compilation. + * + * @param new_size The frame size the function should use. + */ + void SetFrameSize(int new_size) { frame_size = new_size; } - /** Sets this function's outer_id list. */ - void SetOuterIDs(IDPList ids) { outer_ids = std::move(ids); } + /** Sets this function's outer_id list. */ + void SetOuterIDs(IDPList ids) { outer_ids = std::move(ids); } - void Describe(ODesc* d) const override; + void Describe(ODesc* d) const override; protected: - ScriptFunc() : Func(SCRIPT_FUNC) { } + ScriptFunc() : Func(SCRIPT_FUNC) {} - StmtPtr AddInits(StmtPtr body, const std::vector& inits); + StmtPtr AddInits(StmtPtr body, const std::vector& inits); - /** - * Clones this function along with its captures. - */ - FuncPtr DoClone() override; + /** + * Clones this function along with its captures. + */ + FuncPtr DoClone() override; - /** - * Uses the given frame for captures, and generates the - * mapping from captured variables to offsets in the frame. - * Virtual so it can be modified for script optimization uses. - * - * @param f the frame holding the values of capture variables - */ - virtual void SetCaptures(Frame* f); + /** + * Uses the given frame for captures, and generates the + * mapping from captured variables to offsets in the frame. + * Virtual so it can be modified for script optimization uses. + * + * @param f the frame holding the values of capture variables + */ + virtual void SetCaptures(Frame* f); private: - size_t frame_size = 0; + size_t frame_size = 0; - // List of the outer IDs used in the function. - IDPList outer_ids; + // List of the outer IDs used in the function. + IDPList outer_ids; - // Frame for (capture-by-copy) closures. These persist over the - // function's lifetime, providing quasi-globals that maintain - // state across individual calls to the function. - Frame* captures_frame = nullptr; + // Frame for (capture-by-copy) closures. These persist over the + // function's lifetime, providing quasi-globals that maintain + // state across individual calls to the function. + Frame* captures_frame = nullptr; - OffsetMap* captures_offset_mapping = nullptr; + OffsetMap* captures_offset_mapping = nullptr; - // Captures when using ZVal block instead of a Frame. - std::unique_ptr> captures_vec; + // Captures when using ZVal block instead of a Frame. + std::unique_ptr> captures_vec; - // The most recently added/updated body ... - StmtPtr current_body; + // The most recently added/updated body ... + StmtPtr current_body; - // ... and its priority. - int current_priority = 0; - }; + // ... and its priority. + int current_priority = 0; +}; using built_in_func = BifReturnVal (*)(Frame* frame, const Args* args); -class BuiltinFunc final : public Func - { +class BuiltinFunc final : public Func { public: - BuiltinFunc(built_in_func func, const char* name, bool is_pure); - ~BuiltinFunc() override = default; + BuiltinFunc(built_in_func func, const char* name, bool is_pure); + ~BuiltinFunc() override = default; - bool IsPure() const override; - ValPtr Invoke(zeek::Args* args, Frame* parent) const override; - built_in_func TheFunc() const { return func; } + bool IsPure() const override; + ValPtr Invoke(zeek::Args* args, Frame* parent) const override; + built_in_func TheFunc() const { return func; } - void Describe(ODesc* d) const override; + void Describe(ODesc* d) const override; protected: - BuiltinFunc() - { - func = nullptr; - is_pure = 0; - } + BuiltinFunc() { + func = nullptr; + is_pure = 0; + } - built_in_func func; - bool is_pure; - }; + built_in_func func; + bool is_pure; +}; extern bool check_built_in_call(BuiltinFunc* f, CallExpr* call); -struct CallInfo - { - const CallExpr* call; - const Func* func; - const zeek::Args& args; - }; +struct CallInfo { + const CallExpr* call; + const Func* func; + const zeek::Args& args; +}; // Class that collects all the specifics defining a Func. -class FunctionIngredients - { +class FunctionIngredients { public: - // Gathers all of the information from a scope and a function body needed - // to build a function. - FunctionIngredients(ScopePtr scope, StmtPtr body, const std::string& module_name); + // Gathers all of the information from a scope and a function body needed + // to build a function. + FunctionIngredients(ScopePtr scope, StmtPtr body, const std::string& module_name); - const IDPtr& GetID() const { return id; } + const IDPtr& GetID() const { return id; } - const StmtPtr& Body() const { return body; } - void ReplaceBody(StmtPtr new_body) { body = std::move(new_body); } + const StmtPtr& Body() const { return body; } + void ReplaceBody(StmtPtr new_body) { body = std::move(new_body); } - const auto& Inits() const { return inits; } - void ClearInits() { inits.clear(); } + const auto& Inits() const { return inits; } + void ClearInits() { inits.clear(); } - size_t FrameSize() const { return frame_size; } - int Priority() const { return priority; } - const ScopePtr& Scope() const { return scope; } - const auto& Groups() const { return groups; } + size_t FrameSize() const { return frame_size; } + int Priority() const { return priority; } + const ScopePtr& Scope() const { return scope; } + const auto& Groups() const { return groups; } - // Used by script optimization to update lambda ingredients - // after compilation. - void SetFrameSize(size_t _frame_size) { frame_size = _frame_size; } + // Used by script optimization to update lambda ingredients + // after compilation. + void SetFrameSize(size_t _frame_size) { frame_size = _frame_size; } private: - IDPtr id; - StmtPtr body; - std::vector inits; - size_t frame_size = 0; - int priority = 0; - ScopePtr scope; - std::set groups; - }; + IDPtr id; + StmtPtr body; + std::vector inits; + size_t frame_size = 0; + int priority = 0; + ScopePtr scope; + std::set groups; +}; using FunctionIngredientsPtr = std::shared_ptr; @@ -427,19 +404,18 @@ extern bool did_builtin_init; extern std::vector bif_initializers; extern void init_primary_bifs(); -inline void run_bif_initializers() - { - for ( const auto& bi : bif_initializers ) - bi(); +inline void run_bif_initializers() { + for ( const auto& bi : bif_initializers ) + bi(); - bif_initializers = {}; - } + bif_initializers = {}; +} extern void emit_builtin_exception(const char* msg); extern void emit_builtin_exception(const char* msg, const ValPtr& arg); extern void emit_builtin_exception(const char* msg, Obj* arg); - } // namespace detail +} // namespace detail extern std::string render_call_stack(); @@ -448,4 +424,4 @@ extern void emit_builtin_error(const char* msg); extern void emit_builtin_error(const char* msg, const ValPtr&); extern void emit_builtin_error(const char* msg, Obj* arg); - } // namespace zeek +} // namespace zeek diff --git a/src/Hash.cc b/src/Hash.cc index 5c90ebd592..d139b2abb8 100644 --- a/src/Hash.cc +++ b/src/Hash.cc @@ -18,8 +18,7 @@ #include "const.bif.netvar_h" -namespace zeek::detail - { +namespace zeek::detail { alignas(32) uint64_t KeyedHash::shared_highwayhash_key[4]; alignas(32) uint64_t KeyedHash::cluster_highwayhash_key[4]; @@ -27,657 +26,552 @@ alignas(16) unsigned long long KeyedHash::shared_siphash_key[2]; // we use the following lines to not pull in the highwayhash headers in Hash.h - but to check the // types did not change underneath us. -static_assert(std::is_same_v, - "Highwayhash return values must match hash_x_t"); -static_assert(std::is_same_v, - "Highwayhash return values must match hash_x_t"); -static_assert(std::is_same_v, - "Highwayhash return values must match hash_x_t"); +static_assert(std::is_same_v, "Highwayhash return values must match hash_x_t"); +static_assert(std::is_same_v, "Highwayhash return values must match hash_x_t"); +static_assert(std::is_same_v, "Highwayhash return values must match hash_x_t"); -void KeyedHash::InitializeSeeds(const std::array& seed_data) - { - static_assert( - std::is_same_v, - "Highwayhash Key is not unsigned long long[2]"); - static_assert(std::is_same_v, - "Highwayhash HHKey is not uint64_t[4]"); - if ( seeds_initialized ) - return; +void KeyedHash::InitializeSeeds(const std::array& seed_data) { + static_assert(std::is_same_v, + "Highwayhash Key is not unsigned long long[2]"); + static_assert(std::is_same_v, + "Highwayhash HHKey is not uint64_t[4]"); + if ( seeds_initialized ) + return; - // leaving this at being generated by md5, allowing user scripts that use hmac_md5 functionality - // to get the same hash values as before. For now. - internal_md5((const u_char*)seed_data.data(), sizeof(seed_data) - 16, - shared_hmac_md5_key); // The last 128 bits of buf are for siphash - // yes, we use the same buffer twice to initialize two different keys. This should not really be - // a security problem of any kind: hmac-md5 is not really used anymore - and even if it was, the - // hashes should not reveal any information about their initialization vector. - static_assert(sizeof(shared_highwayhash_key) == SHA256_DIGEST_LENGTH); - calculate_digest(Hash_SHA256, (const u_char*)seed_data.data(), sizeof(seed_data) - 16, - reinterpret_cast(shared_highwayhash_key)); - memcpy(shared_siphash_key, reinterpret_cast(seed_data.data()) + 64, 16); + // leaving this at being generated by md5, allowing user scripts that use hmac_md5 functionality + // to get the same hash values as before. For now. + internal_md5((const u_char*)seed_data.data(), sizeof(seed_data) - 16, + shared_hmac_md5_key); // The last 128 bits of buf are for siphash + // yes, we use the same buffer twice to initialize two different keys. This should not really be + // a security problem of any kind: hmac-md5 is not really used anymore - and even if it was, the + // hashes should not reveal any information about their initialization vector. + static_assert(sizeof(shared_highwayhash_key) == SHA256_DIGEST_LENGTH); + calculate_digest(Hash_SHA256, (const u_char*)seed_data.data(), sizeof(seed_data) - 16, + reinterpret_cast(shared_highwayhash_key)); + memcpy(shared_siphash_key, reinterpret_cast(seed_data.data()) + 64, 16); - seeds_initialized = true; - } + seeds_initialized = true; +} -void KeyedHash::InitOptions() - { - calculate_digest(Hash_SHA256, BifConst::digest_salt->Bytes(), BifConst::digest_salt->Len(), - reinterpret_cast(cluster_highwayhash_key)); - } +void KeyedHash::InitOptions() { + calculate_digest(Hash_SHA256, BifConst::digest_salt->Bytes(), BifConst::digest_salt->Len(), + reinterpret_cast(cluster_highwayhash_key)); +} -hash64_t KeyedHash::Hash64(const void* bytes, uint64_t size) - { - return highwayhash::SipHash(shared_siphash_key, static_cast(bytes), size); - } +hash64_t KeyedHash::Hash64(const void* bytes, uint64_t size) { + return highwayhash::SipHash(shared_siphash_key, static_cast(bytes), size); +} -void KeyedHash::Hash128(const void* bytes, uint64_t size, hash128_t* result) - { - highwayhash::InstructionSets::Run( - shared_highwayhash_key, static_cast(bytes), size, result); - } +void KeyedHash::Hash128(const void* bytes, uint64_t size, hash128_t* result) { + highwayhash::InstructionSets::Run(shared_highwayhash_key, static_cast(bytes), + size, result); +} -void KeyedHash::Hash256(const void* bytes, uint64_t size, hash256_t* result) - { - highwayhash::InstructionSets::Run( - shared_highwayhash_key, static_cast(bytes), size, result); - } +void KeyedHash::Hash256(const void* bytes, uint64_t size, hash256_t* result) { + highwayhash::InstructionSets::Run(shared_highwayhash_key, static_cast(bytes), + size, result); +} -hash64_t KeyedHash::StaticHash64(const void* bytes, uint64_t size) - { - hash64_t result = 0; - highwayhash::InstructionSets::Run( - cluster_highwayhash_key, static_cast(bytes), size, &result); - return result; - } +hash64_t KeyedHash::StaticHash64(const void* bytes, uint64_t size) { + hash64_t result = 0; + highwayhash::InstructionSets::Run(cluster_highwayhash_key, + static_cast(bytes), size, &result); + return result; +} -void KeyedHash::StaticHash128(const void* bytes, uint64_t size, hash128_t* result) - { - highwayhash::InstructionSets::Run( - cluster_highwayhash_key, static_cast(bytes), size, result); - } +void KeyedHash::StaticHash128(const void* bytes, uint64_t size, hash128_t* result) { + highwayhash::InstructionSets::Run(cluster_highwayhash_key, + static_cast(bytes), size, result); +} -void KeyedHash::StaticHash256(const void* bytes, uint64_t size, hash256_t* result) - { - highwayhash::InstructionSets::Run( - cluster_highwayhash_key, static_cast(bytes), size, result); - } +void KeyedHash::StaticHash256(const void* bytes, uint64_t size, hash256_t* result) { + highwayhash::InstructionSets::Run(cluster_highwayhash_key, + static_cast(bytes), size, result); +} -void init_hash_function() - { - // Make sure we have already called init_random_seed(). - if ( ! KeyedHash::IsInitialized() ) - reporter->InternalError("Zeek's hash functions aren't fully initialized"); - } +void init_hash_function() { + // Make sure we have already called init_random_seed(). + if ( ! KeyedHash::IsInitialized() ) + reporter->InternalError("Zeek's hash functions aren't fully initialized"); +} -HashKey::HashKey(bool b) - { - Set(b); - } +HashKey::HashKey(bool b) { Set(b); } -HashKey::HashKey(int i) - { - Set(i); - } +HashKey::HashKey(int i) { Set(i); } -HashKey::HashKey(zeek_int_t bi) - { - Set(bi); - } +HashKey::HashKey(zeek_int_t bi) { Set(bi); } -HashKey::HashKey(zeek_uint_t bu) - { - Set(bu); - } +HashKey::HashKey(zeek_uint_t bu) { Set(bu); } -HashKey::HashKey(uint32_t u) - { - Set(u); - } +HashKey::HashKey(uint32_t u) { Set(u); } -HashKey::HashKey(const uint32_t u[], size_t n) - { - size = write_size = n * sizeof(u[0]); - key = (char*)u; - } +HashKey::HashKey(const uint32_t u[], size_t n) { + size = write_size = n * sizeof(u[0]); + key = (char*)u; +} -HashKey::HashKey(double d) - { - Set(d); - } +HashKey::HashKey(double d) { Set(d); } -HashKey::HashKey(const void* p) - { - Set(p); - } +HashKey::HashKey(const void* p) { Set(p); } -HashKey::HashKey(const char* s) - { - size = write_size = strlen(s); // note - skip final \0 - key = (char*)s; - } +HashKey::HashKey(const char* s) { + size = write_size = strlen(s); // note - skip final \0 + key = (char*)s; +} -HashKey::HashKey(const String* s) - { - size = write_size = s->Len(); - key = (char*)s->Bytes(); - } +HashKey::HashKey(const String* s) { + size = write_size = s->Len(); + key = (char*)s->Bytes(); +} -HashKey::HashKey(const void* bytes, size_t arg_size) - { - size = write_size = arg_size; - key = CopyKey((char*)bytes, size); - is_our_dynamic = true; - } +HashKey::HashKey(const void* bytes, size_t arg_size) { + size = write_size = arg_size; + key = CopyKey((char*)bytes, size); + is_our_dynamic = true; +} -HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash) - { - size = write_size = arg_size; - hash = arg_hash; - key = CopyKey((char*)arg_key, size); - is_our_dynamic = true; - } +HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash) { + size = write_size = arg_size; + hash = arg_hash; + key = CopyKey((char*)arg_key, size); + is_our_dynamic = true; +} -HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash, bool /* dont_copy */) - { - size = write_size = arg_size; - hash = arg_hash; - key = (char*)arg_key; - } +HashKey::HashKey(const void* arg_key, size_t arg_size, hash_t arg_hash, bool /* dont_copy */) { + size = write_size = arg_size; + hash = arg_hash; + key = (char*)arg_key; +} -HashKey::HashKey(const HashKey& other) : HashKey(other.key, other.size, other.hash) { } +HashKey::HashKey(const HashKey& other) : HashKey(other.key, other.size, other.hash) {} -HashKey::HashKey(HashKey&& other) noexcept - { - hash = other.hash; - size = other.size; - write_size = other.write_size; - read_size = other.read_size; +HashKey::HashKey(HashKey&& other) noexcept { + hash = other.hash; + size = other.size; + write_size = other.write_size; + read_size = other.read_size; - is_our_dynamic = other.is_our_dynamic; - key = other.key; + is_our_dynamic = other.is_our_dynamic; + key = other.key; - other.size = 0; - other.is_our_dynamic = false; - other.key = nullptr; - } + other.size = 0; + other.is_our_dynamic = false; + other.key = nullptr; +} -HashKey::~HashKey() - { - if ( is_our_dynamic ) - delete[] reinterpret_cast(key); - } +HashKey::~HashKey() { + if ( is_our_dynamic ) + delete[] reinterpret_cast(key); +} -hash_t HashKey::Hash() const - { - if ( hash == 0 ) - hash = HashBytes(key, size); +hash_t HashKey::Hash() const { + if ( hash == 0 ) + hash = HashBytes(key, size); #ifdef DEBUG - if ( zeek::detail::debug_logger.IsEnabled(DBG_HASHKEY) ) - { - ODesc d; - Describe(&d); - DBG_LOG(DBG_HASHKEY, "HashKey %p %s", this, d.Description()); - } + if ( zeek::detail::debug_logger.IsEnabled(DBG_HASHKEY) ) { + ODesc d; + Describe(&d); + DBG_LOG(DBG_HASHKEY, "HashKey %p %s", this, d.Description()); + } #endif - return hash; - } - -void* HashKey::TakeKey() - { - if ( is_our_dynamic ) - { - is_our_dynamic = false; - return key; - } - else - return CopyKey(key, size); - } - -void HashKey::Describe(ODesc* d) const - { - char buf[64]; - snprintf(buf, 16, "%0" PRIx64, hash); - d->Add(buf); - d->SP(); - - if ( size > 0 ) - { - d->Add(IsAllocated() ? "(" : "["); - - for ( size_t i = 0; i < size; i++ ) - { - if ( i > 0 ) - { - d->SP(); - // Extra spacing every 8 bytes, for readability. - if ( i % 8 == 0 ) - d->SP(); - } - - // Don't display unwritten content, only say how much there is. - if ( i > write_size ) - { - d->Add("<+"); - d->Add(static_cast(size - write_size - 1)); - d->Add(" of "); - d->Add(static_cast(size)); - d->Add(" available>"); - break; - } - - snprintf(buf, 3, "%02x", key[i]); - d->Add(buf); - } - - d->Add(IsAllocated() ? ")" : "]"); - } - } - -char* HashKey::CopyKey(const char* k, size_t s) const - { - char* k_copy = new char[s]; // s == 0 is okay, returns non-nil - memcpy(k_copy, k, s); - return k_copy; - } - -hash_t HashKey::HashBytes(const void* bytes, size_t size) - { - return KeyedHash::Hash64(bytes, size); - } - -void HashKey::Set(bool b) - { - key_u.b = b; - key = reinterpret_cast(&key_u); - size = write_size = sizeof(b); - } - -void HashKey::Set(int i) - { - key_u.i = i; - key = reinterpret_cast(&key_u); - size = write_size = sizeof(i); - } - -void HashKey::Set(zeek_int_t bi) - { - key_u.bi = bi; - key = reinterpret_cast(&key_u); - size = write_size = sizeof(bi); - } - -void HashKey::Set(zeek_uint_t bu) - { - key_u.bi = zeek_int_t(bu); - key = reinterpret_cast(&key_u); - size = write_size = sizeof(bu); - } - -void HashKey::Set(uint32_t u) - { - key_u.u32 = u; - key = reinterpret_cast(&key_u); - size = write_size = sizeof(u); - } - -void HashKey::Set(double d) - { - key_u.d = d; - key = reinterpret_cast(&key_u); - size = write_size = sizeof(d); - } - -void HashKey::Set(const void* p) - { - key_u.p = p; - key = reinterpret_cast(&key_u); - size = write_size = sizeof(p); - } - -void HashKey::Reserve(const char* tag, size_t addl_size, size_t alignment) - { - ASSERT(! IsAllocated()); - size_t s0 = size; - size_t s1 = util::memory_size_align(size, alignment); - size = s1 + addl_size; - - DBG_LOG(DBG_HASHKEY, "HashKey %p reserving %lu/%lu: %lu -> %lu -> %lu [%s]", this, addl_size, - alignment, s0, s1, size, tag); - } - -void HashKey::Allocate() - { - if ( key != nullptr && key != reinterpret_cast(&key_u) ) - { - reporter->InternalWarning("usage error in HashKey::Allocate(): already allocated"); - return; - } - - is_our_dynamic = true; - key = reinterpret_cast(new double[size / sizeof(double) + 1]); - - read_size = 0; - write_size = 0; - } - -void HashKey::Write(const char* tag, bool b) - { - Write(tag, &b, sizeof(b), 0); - } - -void HashKey::Write(const char* tag, int i, bool align) - { - if ( ! IsAllocated() ) - { - Set(i); - return; - } - - Write(tag, &i, sizeof(i), align ? sizeof(i) : 0); - } - -void HashKey::Write(const char* tag, zeek_int_t bi, bool align) - { - if ( ! IsAllocated() ) - { - Set(bi); - return; - } - - Write(tag, &bi, sizeof(bi), align ? sizeof(bi) : 0); - } - -void HashKey::Write(const char* tag, zeek_uint_t bu, bool align) - { - if ( ! IsAllocated() ) - { - Set(bu); - return; - } - - Write(tag, &bu, sizeof(bu), align ? sizeof(bu) : 0); - } - -void HashKey::Write(const char* tag, uint32_t u, bool align) - { - if ( ! IsAllocated() ) - { - Set(u); - return; - } - - Write(tag, &u, sizeof(u), align ? sizeof(u) : 0); - } - -void HashKey::Write(const char* tag, double d, bool align) - { - if ( ! IsAllocated() ) - { - Set(d); - return; - } - - Write(tag, &d, sizeof(d), align ? sizeof(d) : 0); - } - -void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignment) - { - size_t s0 = write_size; - AlignWrite(alignment); - size_t s1 = write_size; - EnsureWriteSpace(n); - - memcpy(key + write_size, bytes, n); - write_size += n; - - DBG_LOG(DBG_HASHKEY, "HashKey %p writing %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, - s0, s1, write_size, tag); - } - -void HashKey::SkipWrite(const char* tag, size_t n) - { - DBG_LOG(DBG_HASHKEY, "HashKey %p skip-writing %lu: %lu -> %lu [%s]", this, n, write_size, - write_size + n, tag); - - EnsureWriteSpace(n); - write_size += n; - } - -void HashKey::AlignWrite(size_t alignment) - { - ASSERT(IsAllocated()); - - if ( alignment == 0 ) - return; - - size_t old_size = write_size; - - write_size = util::memory_size_align(write_size, alignment); - - if ( write_size > size ) - reporter->InternalError("buffer overflow in HashKey::AlignWrite(): " - "after alignment, %lu bytes used of %lu allocated", - write_size, size); - - while ( old_size < write_size ) - key[old_size++] = '\0'; - } - -void HashKey::AlignRead(size_t alignment) const - { - ASSERT(IsAllocated()); - - if ( alignment == 0 ) - return; - - int old_size = read_size; - - read_size = util::memory_size_align(read_size, alignment); - - if ( read_size > size ) - reporter->InternalError("buffer overflow in HashKey::AlignRead(): " - "after alignment, %lu bytes used of %lu allocated", - read_size, size); - } - -void HashKey::Read(const char* tag, bool& b) const - { - Read(tag, &b, sizeof(b), 0); - } - -void HashKey::Read(const char* tag, int& i, bool align) const - { - Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); - } - -void HashKey::Read(const char* tag, zeek_int_t& i, bool align) const - { - Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); - } - -void HashKey::Read(const char* tag, zeek_uint_t& u, bool align) const - { - Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); - } - -void HashKey::Read(const char* tag, uint32_t& u, bool align) const - { - Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); - } - -void HashKey::Read(const char* tag, double& d, bool align) const - { - Read(tag, &d, sizeof(d), align ? sizeof(d) : 0); - } - -void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const - { - size_t s0 = read_size; - AlignRead(alignment); - size_t s1 = read_size; - EnsureReadSpace(n); - - // In case out is nil, make sure nothing is to be read, and only memcpy - // when there is a non-zero amount. Memory checkers don't nullpointers - // in memcpy even if the size is 0. - ASSERT(out != nullptr || (out == nullptr && n == 0)); - - if ( n > 0 ) - { - memcpy(out, key + read_size, n); - read_size += n; - } - - DBG_LOG(DBG_HASHKEY, "HashKey %p reading %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, - s0, s1, read_size, tag); - } - -void HashKey::SkipRead(const char* tag, size_t n) const - { - DBG_LOG(DBG_HASHKEY, "HashKey %p skip-reading %lu: %lu -> %lu [%s]", this, n, read_size, - read_size + n, tag); - - EnsureReadSpace(n); - read_size += n; - } - -void HashKey::EnsureWriteSpace(size_t n) const - { - if ( n == 0 ) - return; - - if ( ! IsAllocated() ) - reporter->InternalError("usage error in HashKey::EnsureWriteSpace(): " - "size-checking unreserved buffer"); - if ( write_size + n > size ) - reporter->InternalError("buffer overflow in HashKey::Write(): writing %lu " - "bytes with %lu remaining", - n, size - write_size); - } - -void HashKey::EnsureReadSpace(size_t n) const - { - if ( n == 0 ) - return; - - if ( ! IsAllocated() ) - reporter->InternalError("usage error in HashKey::EnsureReadSpace(): " - "size-checking unreserved buffer"); - if ( read_size + n > size ) - reporter->InternalError("buffer overflow in HashKey::EnsureReadSpace(): reading %lu " - "bytes with %lu remaining", - n, size - read_size); - } - -bool HashKey::operator==(const HashKey& other) const - { - // Quick exit for the same object. - if ( this == &other ) - return true; - - return Equal(other.key, other.size, other.hash); - } - -bool HashKey::operator!=(const HashKey& other) const - { - // Quick exit for different objects. - if ( this != &other ) - return true; - - return ! Equal(other.key, other.size, other.hash); - } - -bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash) const - { - // If the key memory is the same just return true. - if ( key == other_key && size == other_size ) - return true; - - // If either key is nullptr, return false. If they were both nullptr, it - // would have fallen into the above block already. - if ( key == nullptr || other_key == nullptr ) - return false; - - return (hash == other_hash) && (size == other_size) && (memcmp(key, other_key, size) == 0); - } - -HashKey& HashKey::operator=(const HashKey& other) - { - if ( this == &other ) - return *this; - - if ( is_our_dynamic && IsAllocated() ) - delete[] key; - - hash = other.hash; - size = other.size; - is_our_dynamic = true; - write_size = other.write_size; - read_size = other.read_size; - - key = CopyKey(other.key, other.size); - - return *this; - } - -HashKey& HashKey::operator=(HashKey&& other) noexcept - { - if ( this == &other ) - return *this; - - hash = other.hash; - size = other.size; - write_size = other.write_size; - read_size = other.read_size; - - if ( is_our_dynamic && IsAllocated() ) - delete[] key; - - is_our_dynamic = other.is_our_dynamic; - key = other.key; - - other.size = 0; - other.is_our_dynamic = false; - other.key = nullptr; - - return *this; - } + return hash; +} + +void* HashKey::TakeKey() { + if ( is_our_dynamic ) { + is_our_dynamic = false; + return key; + } + else + return CopyKey(key, size); +} + +void HashKey::Describe(ODesc* d) const { + char buf[64]; + snprintf(buf, 16, "%0" PRIx64, hash); + d->Add(buf); + d->SP(); + + if ( size > 0 ) { + d->Add(IsAllocated() ? "(" : "["); + + for ( size_t i = 0; i < size; i++ ) { + if ( i > 0 ) { + d->SP(); + // Extra spacing every 8 bytes, for readability. + if ( i % 8 == 0 ) + d->SP(); + } + + // Don't display unwritten content, only say how much there is. + if ( i > write_size ) { + d->Add("<+"); + d->Add(static_cast(size - write_size - 1)); + d->Add(" of "); + d->Add(static_cast(size)); + d->Add(" available>"); + break; + } + + snprintf(buf, 3, "%02x", key[i]); + d->Add(buf); + } + + d->Add(IsAllocated() ? ")" : "]"); + } +} + +char* HashKey::CopyKey(const char* k, size_t s) const { + char* k_copy = new char[s]; // s == 0 is okay, returns non-nil + memcpy(k_copy, k, s); + return k_copy; +} + +hash_t HashKey::HashBytes(const void* bytes, size_t size) { return KeyedHash::Hash64(bytes, size); } + +void HashKey::Set(bool b) { + key_u.b = b; + key = reinterpret_cast(&key_u); + size = write_size = sizeof(b); +} + +void HashKey::Set(int i) { + key_u.i = i; + key = reinterpret_cast(&key_u); + size = write_size = sizeof(i); +} + +void HashKey::Set(zeek_int_t bi) { + key_u.bi = bi; + key = reinterpret_cast(&key_u); + size = write_size = sizeof(bi); +} + +void HashKey::Set(zeek_uint_t bu) { + key_u.bi = zeek_int_t(bu); + key = reinterpret_cast(&key_u); + size = write_size = sizeof(bu); +} + +void HashKey::Set(uint32_t u) { + key_u.u32 = u; + key = reinterpret_cast(&key_u); + size = write_size = sizeof(u); +} + +void HashKey::Set(double d) { + key_u.d = d; + key = reinterpret_cast(&key_u); + size = write_size = sizeof(d); +} + +void HashKey::Set(const void* p) { + key_u.p = p; + key = reinterpret_cast(&key_u); + size = write_size = sizeof(p); +} + +void HashKey::Reserve(const char* tag, size_t addl_size, size_t alignment) { + ASSERT(! IsAllocated()); + size_t s0 = size; + size_t s1 = util::memory_size_align(size, alignment); + size = s1 + addl_size; + + DBG_LOG(DBG_HASHKEY, "HashKey %p reserving %lu/%lu: %lu -> %lu -> %lu [%s]", this, addl_size, alignment, s0, s1, + size, tag); +} + +void HashKey::Allocate() { + if ( key != nullptr && key != reinterpret_cast(&key_u) ) { + reporter->InternalWarning("usage error in HashKey::Allocate(): already allocated"); + return; + } + + is_our_dynamic = true; + key = reinterpret_cast(new double[size / sizeof(double) + 1]); + + read_size = 0; + write_size = 0; +} + +void HashKey::Write(const char* tag, bool b) { Write(tag, &b, sizeof(b), 0); } + +void HashKey::Write(const char* tag, int i, bool align) { + if ( ! IsAllocated() ) { + Set(i); + return; + } + + Write(tag, &i, sizeof(i), align ? sizeof(i) : 0); +} + +void HashKey::Write(const char* tag, zeek_int_t bi, bool align) { + if ( ! IsAllocated() ) { + Set(bi); + return; + } + + Write(tag, &bi, sizeof(bi), align ? sizeof(bi) : 0); +} + +void HashKey::Write(const char* tag, zeek_uint_t bu, bool align) { + if ( ! IsAllocated() ) { + Set(bu); + return; + } + + Write(tag, &bu, sizeof(bu), align ? sizeof(bu) : 0); +} + +void HashKey::Write(const char* tag, uint32_t u, bool align) { + if ( ! IsAllocated() ) { + Set(u); + return; + } + + Write(tag, &u, sizeof(u), align ? sizeof(u) : 0); +} + +void HashKey::Write(const char* tag, double d, bool align) { + if ( ! IsAllocated() ) { + Set(d); + return; + } + + Write(tag, &d, sizeof(d), align ? sizeof(d) : 0); +} + +void HashKey::Write(const char* tag, const void* bytes, size_t n, size_t alignment) { + size_t s0 = write_size; + AlignWrite(alignment); + size_t s1 = write_size; + EnsureWriteSpace(n); + + memcpy(key + write_size, bytes, n); + write_size += n; + + DBG_LOG(DBG_HASHKEY, "HashKey %p writing %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, s0, s1, write_size, + tag); +} + +void HashKey::SkipWrite(const char* tag, size_t n) { + DBG_LOG(DBG_HASHKEY, "HashKey %p skip-writing %lu: %lu -> %lu [%s]", this, n, write_size, write_size + n, tag); + + EnsureWriteSpace(n); + write_size += n; +} + +void HashKey::AlignWrite(size_t alignment) { + ASSERT(IsAllocated()); + + if ( alignment == 0 ) + return; + + size_t old_size = write_size; + + write_size = util::memory_size_align(write_size, alignment); + + if ( write_size > size ) + reporter->InternalError( + "buffer overflow in HashKey::AlignWrite(): " + "after alignment, %lu bytes used of %lu allocated", + write_size, size); + + while ( old_size < write_size ) + key[old_size++] = '\0'; +} + +void HashKey::AlignRead(size_t alignment) const { + ASSERT(IsAllocated()); + + if ( alignment == 0 ) + return; + + int old_size = read_size; + + read_size = util::memory_size_align(read_size, alignment); + + if ( read_size > size ) + reporter->InternalError( + "buffer overflow in HashKey::AlignRead(): " + "after alignment, %lu bytes used of %lu allocated", + read_size, size); +} + +void HashKey::Read(const char* tag, bool& b) const { Read(tag, &b, sizeof(b), 0); } + +void HashKey::Read(const char* tag, int& i, bool align) const { Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); } + +void HashKey::Read(const char* tag, zeek_int_t& i, bool align) const { + Read(tag, &i, sizeof(i), align ? sizeof(i) : 0); +} + +void HashKey::Read(const char* tag, zeek_uint_t& u, bool align) const { + Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); +} + +void HashKey::Read(const char* tag, uint32_t& u, bool align) const { Read(tag, &u, sizeof(u), align ? sizeof(u) : 0); } + +void HashKey::Read(const char* tag, double& d, bool align) const { Read(tag, &d, sizeof(d), align ? sizeof(d) : 0); } + +void HashKey::Read(const char* tag, void* out, size_t n, size_t alignment) const { + size_t s0 = read_size; + AlignRead(alignment); + size_t s1 = read_size; + EnsureReadSpace(n); + + // In case out is nil, make sure nothing is to be read, and only memcpy + // when there is a non-zero amount. Memory checkers don't nullpointers + // in memcpy even if the size is 0. + ASSERT(out != nullptr || (out == nullptr && n == 0)); + + if ( n > 0 ) { + memcpy(out, key + read_size, n); + read_size += n; + } + + DBG_LOG(DBG_HASHKEY, "HashKey %p reading %lu/%lu: %lu -> %lu -> %lu [%s]", this, n, alignment, s0, s1, read_size, + tag); +} + +void HashKey::SkipRead(const char* tag, size_t n) const { + DBG_LOG(DBG_HASHKEY, "HashKey %p skip-reading %lu: %lu -> %lu [%s]", this, n, read_size, read_size + n, tag); + + EnsureReadSpace(n); + read_size += n; +} + +void HashKey::EnsureWriteSpace(size_t n) const { + if ( n == 0 ) + return; + + if ( ! IsAllocated() ) + reporter->InternalError( + "usage error in HashKey::EnsureWriteSpace(): " + "size-checking unreserved buffer"); + if ( write_size + n > size ) + reporter->InternalError( + "buffer overflow in HashKey::Write(): writing %lu " + "bytes with %lu remaining", + n, size - write_size); +} + +void HashKey::EnsureReadSpace(size_t n) const { + if ( n == 0 ) + return; + + if ( ! IsAllocated() ) + reporter->InternalError( + "usage error in HashKey::EnsureReadSpace(): " + "size-checking unreserved buffer"); + if ( read_size + n > size ) + reporter->InternalError( + "buffer overflow in HashKey::EnsureReadSpace(): reading %lu " + "bytes with %lu remaining", + n, size - read_size); +} + +bool HashKey::operator==(const HashKey& other) const { + // Quick exit for the same object. + if ( this == &other ) + return true; + + return Equal(other.key, other.size, other.hash); +} + +bool HashKey::operator!=(const HashKey& other) const { + // Quick exit for different objects. + if ( this != &other ) + return true; + + return ! Equal(other.key, other.size, other.hash); +} + +bool HashKey::Equal(const void* other_key, size_t other_size, hash_t other_hash) const { + // If the key memory is the same just return true. + if ( key == other_key && size == other_size ) + return true; + + // If either key is nullptr, return false. If they were both nullptr, it + // would have fallen into the above block already. + if ( key == nullptr || other_key == nullptr ) + return false; + + return (hash == other_hash) && (size == other_size) && (memcmp(key, other_key, size) == 0); +} + +HashKey& HashKey::operator=(const HashKey& other) { + if ( this == &other ) + return *this; + + if ( is_our_dynamic && IsAllocated() ) + delete[] key; + + hash = other.hash; + size = other.size; + is_our_dynamic = true; + write_size = other.write_size; + read_size = other.read_size; + + key = CopyKey(other.key, other.size); + + return *this; +} + +HashKey& HashKey::operator=(HashKey&& other) noexcept { + if ( this == &other ) + return *this; + + hash = other.hash; + size = other.size; + write_size = other.write_size; + read_size = other.read_size; + + if ( is_our_dynamic && IsAllocated() ) + delete[] key; + + is_our_dynamic = other.is_our_dynamic; + key = other.key; + + other.size = 0; + other.is_our_dynamic = false; + other.key = nullptr; + + return *this; +} TEST_SUITE_BEGIN("Hash"); -TEST_CASE("equality") - { - HashKey h1(12345); - HashKey h2(12345); - HashKey h3(67890); +TEST_CASE("equality") { + HashKey h1(12345); + HashKey h2(12345); + HashKey h3(67890); - CHECK(h1 == h2); - CHECK(h1 != h3); - } + CHECK(h1 == h2); + CHECK(h1 != h3); +} -TEST_CASE("copy assignment") - { - HashKey h1(12345); - HashKey h2 = h1; - HashKey h3{h1}; +TEST_CASE("copy assignment") { + HashKey h1(12345); + HashKey h2 = h1; + HashKey h3{h1}; - CHECK(h1 == h2); - CHECK(h1 == h3); - } + CHECK(h1 == h2); + CHECK(h1 == h3); +} -TEST_CASE("move assignment") - { - HashKey h1(12345); - HashKey h2(12345); - HashKey h3(12345); +TEST_CASE("move assignment") { + HashKey h1(12345); + HashKey h2(12345); + HashKey h3(12345); - HashKey h4 = std::move(h2); - HashKey h5{h3}; + HashKey h4 = std::move(h2); + HashKey h5{h3}; - CHECK(h1 == h4); - CHECK(h1 == h5); - } + CHECK(h1 == h4); + CHECK(h1 == h5); +} TEST_SUITE_END(); - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/Hash.h b/src/Hash.h index 615eefd466..3ef55b20ac 100644 --- a/src/Hash.h +++ b/src/Hash.h @@ -27,368 +27,356 @@ // to allow md5_hmac_bif access to the hmac seed #include "zeek/ZeekArgs.h" -namespace zeek - { +namespace zeek { class String; class ODesc; - } +} // namespace zeek -namespace zeek::detail - { +namespace zeek::detail { class Frame; class BifReturnVal; - } +} // namespace zeek::detail -namespace zeek::BifFunc - { +namespace zeek::BifFunc { extern zeek::detail::BifReturnVal md5_hmac_bif(zeek::detail::Frame* frame, const zeek::Args*); - } +} -namespace zeek::detail - { +namespace zeek::detail { using hash_t = uint64_t; using hash64_t = uint64_t; using hash128_t = uint64_t[2]; using hash256_t = uint64_t[4]; -class KeyedHash - { +class KeyedHash { public: - /** - * Generate a 64 bit digest hash. - * - * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment - * variable is set. Thus, typically every node will return a different hash - * after every restart. - * - * This should be used for internal hashes that do not have to be stable over - * the cluster/runs - like, e.g. connection ID generation. - * - * @param bytes Bytes to hash - * - * @param size Size of bytes - * - * @returns 64 bit digest hash - */ - static hash64_t Hash64(const void* bytes, uint64_t size); + /** + * Generate a 64 bit digest hash. + * + * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment + * variable is set. Thus, typically every node will return a different hash + * after every restart. + * + * This should be used for internal hashes that do not have to be stable over + * the cluster/runs - like, e.g. connection ID generation. + * + * @param bytes Bytes to hash + * + * @param size Size of bytes + * + * @returns 64 bit digest hash + */ + static hash64_t Hash64(const void* bytes, uint64_t size); - /** - * Generate a 128 bit digest hash. - * - * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment - * variable is set. Thus, typically every node will return a different hash - * after every restart. - * - * This should be used for internal hashes that do not have to be stable over - * the cluster/runs - like, e.g. connection ID generation. - * - * @param bytes Bytes to hash - * - * @param size Size of bytes - * - * @param result Result of the hashing operation. - */ - static void Hash128(const void* bytes, uint64_t size, hash128_t* result); + /** + * Generate a 128 bit digest hash. + * + * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment + * variable is set. Thus, typically every node will return a different hash + * after every restart. + * + * This should be used for internal hashes that do not have to be stable over + * the cluster/runs - like, e.g. connection ID generation. + * + * @param bytes Bytes to hash + * + * @param size Size of bytes + * + * @param result Result of the hashing operation. + */ + static void Hash128(const void* bytes, uint64_t size, hash128_t* result); - /** - * Generate a 256 bit digest hash. - * - * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment - * variable is set. Thus, typically every node will return a different hash - * after every restart. - * - * This should be used for internal hashes that do not have to be stable over - * the cluster/runs - like, e.g. connection ID generation. - * - * @param bytes Bytes to hash - * - * @param size Size of bytes - * - * @param result Result of the hashing operation. - */ - static void Hash256(const void* bytes, uint64_t size, hash256_t* result); + /** + * Generate a 256 bit digest hash. + * + * This hash is seeded with random data, unless the ZEEK_SEED_FILE environment + * variable is set. Thus, typically every node will return a different hash + * after every restart. + * + * This should be used for internal hashes that do not have to be stable over + * the cluster/runs - like, e.g. connection ID generation. + * + * @param bytes Bytes to hash + * + * @param size Size of bytes + * + * @param result Result of the hashing operation. + */ + static void Hash256(const void* bytes, uint64_t size, hash256_t* result); - /** - * Generates a installation-specific 64 bit hash. - * - * This function generates a 64 bit digest hash, which is stable over a cluster - * or a restart. - * - * To be more exact - the seed value for this hash is generated from the script-level - * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value - * is not changed. - * - * This should be used for hashes that have to remain stable over the entire - * cluster. An example are file IDs, which have to be stable over several workers. - * - * @param bytes Bytes to hash - * - * @param size Size of bytes - * - * @returns 64 bit digest hash - */ - static hash64_t StaticHash64(const void* bytes, uint64_t size); + /** + * Generates a installation-specific 64 bit hash. + * + * This function generates a 64 bit digest hash, which is stable over a cluster + * or a restart. + * + * To be more exact - the seed value for this hash is generated from the script-level + * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value + * is not changed. + * + * This should be used for hashes that have to remain stable over the entire + * cluster. An example are file IDs, which have to be stable over several workers. + * + * @param bytes Bytes to hash + * + * @param size Size of bytes + * + * @returns 64 bit digest hash + */ + static hash64_t StaticHash64(const void* bytes, uint64_t size); - /** - * Generates a installation-specific 128 bit hash. - * - * This function generates a 128 bit digest hash, which is stable over a cluster - * or a restart. - * - * To be more exact - the seed value for this hash is generated from the script-level - * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value - * is not changed. - * - * This should be used for hashes that have to remain stable over the entire - * cluster. An example are file IDs, which have to be stable over several workers. - * - * @param bytes Bytes to hash - * - * @param size Size of bytes - * - * @param result Result of the hashing operation. - */ - static void StaticHash128(const void* bytes, uint64_t size, hash128_t* result); + /** + * Generates a installation-specific 128 bit hash. + * + * This function generates a 128 bit digest hash, which is stable over a cluster + * or a restart. + * + * To be more exact - the seed value for this hash is generated from the script-level + * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value + * is not changed. + * + * This should be used for hashes that have to remain stable over the entire + * cluster. An example are file IDs, which have to be stable over several workers. + * + * @param bytes Bytes to hash + * + * @param size Size of bytes + * + * @param result Result of the hashing operation. + */ + static void StaticHash128(const void* bytes, uint64_t size, hash128_t* result); - /** - * Generates a installation-specific 256 bit hash. - * - * This function generates a 128 bit digest hash, which is stable over a cluster - * or a restart. - * - * To be more exact - the seed value for this hash is generated from the script-level - * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value - * is not changed. - * - * This should be used for hashes that have to remain stable over the entire - * cluster. An example are file IDs, which have to be stable over several workers. - * - * @param bytes Bytes to hash - * - * @param size Size of bytes - * - * @param result Result of the hashing operation. - */ - static void StaticHash256(const void* bytes, uint64_t size, hash256_t* result); + /** + * Generates a installation-specific 256 bit hash. + * + * This function generates a 128 bit digest hash, which is stable over a cluster + * or a restart. + * + * To be more exact - the seed value for this hash is generated from the script-level + * :zeek:see:`digest_salt` constant. The seeds are stable as long as this value + * is not changed. + * + * This should be used for hashes that have to remain stable over the entire + * cluster. An example are file IDs, which have to be stable over several workers. + * + * @param bytes Bytes to hash + * + * @param size Size of bytes + * + * @param result Result of the hashing operation. + */ + static void StaticHash256(const void* bytes, uint64_t size, hash256_t* result); - /** - * Size of the initial seed - */ - constexpr static int SEED_INIT_SIZE = 20; + /** + * Size of the initial seed + */ + constexpr static int SEED_INIT_SIZE = 20; - /** - * Initialize the (typically process-specific) seeds. This function is indirectly - * called from main, during early initialization. - * - * @param seed_data random data used as an initial seed - */ - static void InitializeSeeds(const std::array& seed_data); + /** + * Initialize the (typically process-specific) seeds. This function is indirectly + * called from main, during early initialization. + * + * @param seed_data random data used as an initial seed + */ + static void InitializeSeeds(const std::array& seed_data); - /** - * Returns true if the process-specific seeds have been initialized - * - * @return True if the seeds are initialized - */ - static bool IsInitialized() { return seeds_initialized; } + /** + * Returns true if the process-specific seeds have been initialized + * + * @return True if the seeds are initialized + */ + static bool IsInitialized() { return seeds_initialized; } - /** - * Initializes the static hash seeds using the script-level - * :zeek:see:`digest_salt` constant. - */ - static void InitOptions(); + /** + * Initializes the static hash seeds using the script-level + * :zeek:see:`digest_salt` constant. + */ + static void InitOptions(); private: - // actually HHKey. This key changes each start (unless a seed is specified) - alignas(32) static uint64_t shared_highwayhash_key[4]; - // actually HHKey. This key is installation specific and sourced from the digest_salt - // script-level const. - alignas(32) static uint64_t cluster_highwayhash_key[4]; - // actually HH_U64, which has the same type. This key changes each start (unless a seed is - // specified) - alignas(16) static unsigned long long shared_siphash_key[2]; - // This key changes each start (unless a seed is specified) - inline static uint8_t shared_hmac_md5_key[16]; - inline static bool seeds_initialized = false; + // actually HHKey. This key changes each start (unless a seed is specified) + alignas(32) static uint64_t shared_highwayhash_key[4]; + // actually HHKey. This key is installation specific and sourced from the digest_salt + // script-level const. + alignas(32) static uint64_t cluster_highwayhash_key[4]; + // actually HH_U64, which has the same type. This key changes each start (unless a seed is + // specified) + alignas(16) static unsigned long long shared_siphash_key[2]; + // This key changes each start (unless a seed is specified) + inline static uint8_t shared_hmac_md5_key[16]; + inline static bool seeds_initialized = false; - friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, - unsigned char digest[16]); - friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*); - }; + friend void util::detail::hmac_md5(size_t size, const unsigned char* bytes, unsigned char digest[16]); + friend BifReturnVal BifFunc::md5_hmac_bif(zeek::detail::Frame* frame, const Args*); +}; -enum HashKeyTag - { - HASH_KEY_INT, - HASH_KEY_DOUBLE, - HASH_KEY_STRING - }; +enum HashKeyTag { HASH_KEY_INT, HASH_KEY_DOUBLE, HASH_KEY_STRING }; constexpr int NUM_HASH_KEYS = HASH_KEY_STRING + 1; -class HashKey - { +class HashKey { public: - explicit HashKey() { key_u.u32 = 0; } - explicit HashKey(bool b); - explicit HashKey(int i); - explicit HashKey(zeek_int_t bi); - explicit HashKey(zeek_uint_t bu); - explicit HashKey(uint32_t u); - HashKey(const uint32_t u[], size_t n); - explicit HashKey(double d); - explicit HashKey(const void* p); - explicit HashKey(const char* s); // No copying, no ownership - explicit HashKey(const String* s); // No copying, no ownership + explicit HashKey() { key_u.u32 = 0; } + explicit HashKey(bool b); + explicit HashKey(int i); + explicit HashKey(zeek_int_t bi); + explicit HashKey(zeek_uint_t bu); + explicit HashKey(uint32_t u); + HashKey(const uint32_t u[], size_t n); + explicit HashKey(double d); + explicit HashKey(const void* p); + explicit HashKey(const char* s); // No copying, no ownership + explicit HashKey(const String* s); // No copying, no ownership - // Builds a key from the given chunk of bytes. Copies the data. - HashKey(const void* bytes, size_t size); + // Builds a key from the given chunk of bytes. Copies the data. + HashKey(const void* bytes, size_t size); - // Create a HashKey given all of its components. Copies the key. - HashKey(const void* key, size_t size, hash_t hash); + // Create a HashKey given all of its components. Copies the key. + HashKey(const void* key, size_t size, hash_t hash); - // Create a Hashkey given all of its components *without* - // copying the key and *without* taking ownership. Note that - // "dont_copy" is a type placeholder to differentiate this member - // function from the one above; its value is not used. - HashKey(const void* key, size_t size, hash_t hash, bool dont_copy); + // Create a Hashkey given all of its components *without* + // copying the key and *without* taking ownership. Note that + // "dont_copy" is a type placeholder to differentiate this member + // function from the one above; its value is not used. + HashKey(const void* key, size_t size, hash_t hash, bool dont_copy); - // Copy constructor. Always copies the key. - HashKey(const HashKey& other); + // Copy constructor. Always copies the key. + HashKey(const HashKey& other); - // Move constructor. Takes ownership of the key. - HashKey(HashKey&& other) noexcept; + // Move constructor. Takes ownership of the key. + HashKey(HashKey&& other) noexcept; - // Destructor - ~HashKey(); + // Destructor + ~HashKey(); - // Hands over the key to the caller. This means that if the - // key is our dynamic, we give it to the caller and mark it - // as not our dynamic. If initially it's not our dynamic, - // we give them a copy of it. - void* TakeKey(); + // Hands over the key to the caller. This means that if the + // key is our dynamic, we give it to the caller and mark it + // as not our dynamic. If initially it's not our dynamic, + // we give them a copy of it. + void* TakeKey(); - const void* Key() const { return key; } - size_t Size() const { return size; } - hash_t Hash() const; + const void* Key() const { return key; } + size_t Size() const { return size; } + hash_t Hash() const; - static hash_t HashBytes(const void* bytes, size_t size); + static hash_t HashBytes(const void* bytes, size_t size); - // A HashKey is "allocated" when the underlying key points somewhere - // other than our internal key_u union. This is almost like - // is_our_dynamic, but remains true also after TakeKey(). - bool IsAllocated() const - { - return (key != nullptr && key != reinterpret_cast(&key_u)); - } + // A HashKey is "allocated" when the underlying key points somewhere + // other than our internal key_u union. This is almost like + // is_our_dynamic, but remains true also after TakeKey(). + bool IsAllocated() const { return (key != nullptr && key != reinterpret_cast(&key_u)); } - // Buffer size reservation. Repeated calls to these methods - // incrementally build up the eventual buffer size to be allocated via - // Allocate(). - template void ReserveType(const char* tag) { Reserve(tag, sizeof(T), sizeof(T)); } - void Reserve(const char* tag, size_t addl_size, size_t alignment = 0); + // Buffer size reservation. Repeated calls to these methods + // incrementally build up the eventual buffer size to be allocated via + // Allocate(). + template + void ReserveType(const char* tag) { + Reserve(tag, sizeof(T), sizeof(T)); + } + void Reserve(const char* tag, size_t addl_size, size_t alignment = 0); - // Allocates the reserved amount of memory - void Allocate(); + // Allocates the reserved amount of memory + void Allocate(); - // Incremental writes into an allocated HashKey. The tags give context - // to what's being written and are only used in debug-build log streams. - // When true, the alignment boolean will cause write-marker alignment to - // the size of the item being written, otherwise writes happen directly - // at the current marker. - void Write(const char* tag, bool b); - void Write(const char* tag, int i, bool align = true); - void Write(const char* tag, zeek_int_t bi, bool align = true); - void Write(const char* tag, zeek_uint_t bu, bool align = true); - void Write(const char* tag, uint32_t u, bool align = true); - void Write(const char* tag, double d, bool align = true); + // Incremental writes into an allocated HashKey. The tags give context + // to what's being written and are only used in debug-build log streams. + // When true, the alignment boolean will cause write-marker alignment to + // the size of the item being written, otherwise writes happen directly + // at the current marker. + void Write(const char* tag, bool b); + void Write(const char* tag, int i, bool align = true); + void Write(const char* tag, zeek_int_t bi, bool align = true); + void Write(const char* tag, zeek_uint_t bu, bool align = true); + void Write(const char* tag, uint32_t u, bool align = true); + void Write(const char* tag, double d, bool align = true); - void Write(const char* tag, const void* bytes, size_t n, size_t alignment = 0); + void Write(const char* tag, const void* bytes, size_t n, size_t alignment = 0); - // For writes that copy directly into the allocated buffer, this method - // advances the write marker without modifying content. - void SkipWrite(const char* tag, size_t n); + // For writes that copy directly into the allocated buffer, this method + // advances the write marker without modifying content. + void SkipWrite(const char* tag, size_t n); - // Aligns the write marker to the next multiple of the given alignment size. - void AlignWrite(size_t alignment); + // Aligns the write marker to the next multiple of the given alignment size. + void AlignWrite(size_t alignment); - // Bounds check: if the buffer does not have at least n bytes available - // to write into, triggers an InternalError. - void EnsureWriteSpace(size_t n) const; + // Bounds check: if the buffer does not have at least n bytes available + // to write into, triggers an InternalError. + void EnsureWriteSpace(size_t n) const; - // Reads don't modify our internal state except for the read offset - // pointer. To blend in more seamlessly with the rest of Zeek we keep - // reads a const operation. - void ResetRead() const { read_size = 0; } + // Reads don't modify our internal state except for the read offset + // pointer. To blend in more seamlessly with the rest of Zeek we keep + // reads a const operation. + void ResetRead() const { read_size = 0; } - // Incremental reads from an allocated HashKey. As with writes, the - // tags are only used for debug-build logging, and alignment prior - // to the read of the item is controlled by the align boolean. - void Read(const char* tag, bool& b) const; - void Read(const char* tag, int& i, bool align = true) const; - void Read(const char* tag, zeek_int_t& bi, bool align = true) const; - void Read(const char* tag, zeek_uint_t& bu, bool align = true) const; - void Read(const char* tag, uint32_t& u, bool align = true) const; - void Read(const char* tag, double& d, bool align = true) const; + // Incremental reads from an allocated HashKey. As with writes, the + // tags are only used for debug-build logging, and alignment prior + // to the read of the item is controlled by the align boolean. + void Read(const char* tag, bool& b) const; + void Read(const char* tag, int& i, bool align = true) const; + void Read(const char* tag, zeek_int_t& bi, bool align = true) const; + void Read(const char* tag, zeek_uint_t& bu, bool align = true) const; + void Read(const char* tag, uint32_t& u, bool align = true) const; + void Read(const char* tag, double& d, bool align = true) const; - void Read(const char* tag, void* out, size_t n, size_t alignment = 0) const; + void Read(const char* tag, void* out, size_t n, size_t alignment = 0) const; - // These mirror the corresponding write methods above. - void SkipRead(const char* tag, size_t n) const; - void AlignRead(size_t alignment) const; - void EnsureReadSpace(size_t n) const; + // These mirror the corresponding write methods above. + void SkipRead(const char* tag, size_t n) const; + void AlignRead(size_t alignment) const; + void EnsureReadSpace(size_t n) const; - void* KeyAtWrite() { return static_cast(key + write_size); } - const void* KeyAtRead() const { return static_cast(key + read_size); } - const void* KeyEnd() const { return static_cast(key + size); } + void* KeyAtWrite() { return static_cast(key + write_size); } + const void* KeyAtRead() const { return static_cast(key + read_size); } + const void* KeyEnd() const { return static_cast(key + size); } - void Describe(ODesc* d) const; + void Describe(ODesc* d) const; - bool operator==(const HashKey& other) const; - bool operator!=(const HashKey& other) const; + bool operator==(const HashKey& other) const; + bool operator!=(const HashKey& other) const; - bool Equal(const void* other_key, size_t other_size, hash_t other_hash) const; + bool Equal(const void* other_key, size_t other_size, hash_t other_hash) const; - // Copy operator. Always copies the key. - HashKey& operator=(const HashKey& other); + // Copy operator. Always copies the key. + HashKey& operator=(const HashKey& other); - // Move operator. Takes ownership of the key. - HashKey& operator=(HashKey&& other) noexcept; + // Move operator. Takes ownership of the key. + HashKey& operator=(HashKey&& other) noexcept; protected: - char* CopyKey(const char* key, size_t size) const; + char* CopyKey(const char* key, size_t size) const; - // Payload setters for types stored directly in the key_u union. These - // adjust the size and write_size markers to indicate a full buffer, and - // use the key_u union for storage. - void Set(bool b); - void Set(int i); - void Set(zeek_int_t bi); - void Set(zeek_uint_t bu); - void Set(uint32_t u); - void Set(double d); - void Set(const void* p); + // Payload setters for types stored directly in the key_u union. These + // adjust the size and write_size markers to indicate a full buffer, and + // use the key_u union for storage. + void Set(bool b); + void Set(int i); + void Set(zeek_int_t bi); + void Set(zeek_uint_t bu); + void Set(uint32_t u); + void Set(double d); + void Set(const void* p); - union { - bool b; - int i; - zeek_int_t bi; - uint32_t u32; - double d; - const void* p; - } key_u; + union { + bool b; + int i; + zeek_int_t bi; + uint32_t u32; + double d; + const void* p; + } key_u; - char* key = nullptr; - mutable hash_t hash = 0; - size_t size = 0; - bool is_our_dynamic = false; - size_t write_size = 0; - mutable size_t read_size = 0; - }; + char* key = nullptr; + mutable hash_t hash = 0; + size_t size = 0; + bool is_our_dynamic = false; + size_t write_size = 0; + mutable size_t read_size = 0; +}; extern void init_hash_function(); - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/ID.cc b/src/ID.cc index 67c9c2ffec..72effdb4a3 100644 --- a/src/ID.cc +++ b/src/ID.cc @@ -22,8 +22,7 @@ #include "zeek/zeekygen/ScriptInfo.h" #include "zeek/zeekygen/utils.h" -namespace zeek - { +namespace zeek { RecordTypePtr id::conn_id; RecordTypePtr id::endpoint; @@ -37,662 +36,554 @@ TableTypePtr id::count_set; VectorTypePtr id::string_vec; VectorTypePtr id::index_vec; -const detail::IDPtr& id::find(std::string_view name) - { - return zeek::detail::global_scope()->Find(name); - } +const detail::IDPtr& id::find(std::string_view name) { return zeek::detail::global_scope()->Find(name); } -const TypePtr& id::find_type(std::string_view name) - { - auto id = zeek::detail::global_scope()->Find(name); +const TypePtr& id::find_type(std::string_view name) { + auto id = zeek::detail::global_scope()->Find(name); - if ( ! id ) - reporter->InternalError("Failed to find type named: %s", std::string(name).data()); + if ( ! id ) + reporter->InternalError("Failed to find type named: %s", std::string(name).data()); - return id->GetType(); - } + return id->GetType(); +} -const ValPtr& id::find_val(std::string_view name) - { - auto id = zeek::detail::global_scope()->Find(name); +const ValPtr& id::find_val(std::string_view name) { + auto id = zeek::detail::global_scope()->Find(name); - if ( ! id ) - reporter->InternalError("Failed to find variable named: %s", std::string(name).data()); + if ( ! id ) + reporter->InternalError("Failed to find variable named: %s", std::string(name).data()); - return id->GetVal(); - } + return id->GetVal(); +} -const ValPtr& id::find_const(std::string_view name) - { - auto id = zeek::detail::global_scope()->Find(name); +const ValPtr& id::find_const(std::string_view name) { + auto id = zeek::detail::global_scope()->Find(name); - if ( ! id ) - reporter->InternalError("Failed to find variable named: %s", std::string(name).data()); + if ( ! id ) + reporter->InternalError("Failed to find variable named: %s", std::string(name).data()); - if ( ! id->IsConst() ) - reporter->InternalError("Variable is not 'const', but expected to be: %s", - std::string(name).data()); + if ( ! id->IsConst() ) + reporter->InternalError("Variable is not 'const', but expected to be: %s", std::string(name).data()); - return id->GetVal(); - } + return id->GetVal(); +} -FuncPtr id::find_func(std::string_view name) - { - const auto& v = id::find_val(name); +FuncPtr id::find_func(std::string_view name) { + const auto& v = id::find_val(name); - if ( ! v ) - return nullptr; + if ( ! v ) + return nullptr; - if ( ! IsFunc(v->GetType()->Tag()) ) - reporter->InternalError("Expected variable '%s' to be a function", - std::string(name).data()); + if ( ! IsFunc(v->GetType()->Tag()) ) + reporter->InternalError("Expected variable '%s' to be a function", std::string(name).data()); - return v.get()->As()->AsFuncPtr(); - } + return v.get()->As()->AsFuncPtr(); +} -void id::detail::init_types() - { - conn_id = id::find_type("conn_id"); - endpoint = id::find_type("endpoint"); - connection = id::find_type("connection"); - fa_file = id::find_type("fa_file"); - fa_metadata = id::find_type("fa_metadata"); - transport_proto = id::find_type("transport_proto"); - string_set = id::find_type("string_set"); - string_array = id::find_type("string_array"); - count_set = id::find_type("count_set"); - string_vec = id::find_type("string_vec"); - index_vec = id::find_type("index_vec"); - } +void id::detail::init_types() { + conn_id = id::find_type("conn_id"); + endpoint = id::find_type("endpoint"); + connection = id::find_type("connection"); + fa_file = id::find_type("fa_file"); + fa_metadata = id::find_type("fa_metadata"); + transport_proto = id::find_type("transport_proto"); + string_set = id::find_type("string_set"); + string_array = id::find_type("string_array"); + count_set = id::find_type("count_set"); + string_vec = id::find_type("string_vec"); + index_vec = id::find_type("index_vec"); +} -namespace detail - { +namespace detail { -ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export) - { - name = util::copy_string(arg_name); - scope = arg_scope; - is_export = arg_is_export; - is_option = false; - is_blank = name && extract_var_name(name) == "_"; - is_const = false; - is_enum_const = false; - is_type = false; - offset = 0; +ID::ID(const char* arg_name, IDScope arg_scope, bool arg_is_export) { + name = util::copy_string(arg_name); + scope = arg_scope; + is_export = arg_is_export; + is_option = false; + is_blank = name && extract_var_name(name) == "_"; + is_const = false; + is_enum_const = false; + is_type = false; + offset = 0; - if ( is_blank ) - SetType(base_type(TYPE_ANY)); + if ( is_blank ) + SetType(base_type(TYPE_ANY)); - opt_info = new IDOptInfo(this); + opt_info = new IDOptInfo(this); - infer_return_type = false; + infer_return_type = false; - SetLocationInfo(&start_location, &end_location); - } + SetLocationInfo(&start_location, &end_location); +} -ID::~ID() - { - ClearOptInfo(); - delete[] name; - } +ID::~ID() { + ClearOptInfo(); + delete[] name; +} -std::string ID::ModuleName() const - { - return extract_module_name(name); - } +std::string ID::ModuleName() const { return extract_module_name(name); } -void ID::SetType(TypePtr t) - { - type = std::move(t); - } +void ID::SetType(TypePtr t) { type = std::move(t); } -void ID::ClearVal() - { - val = nullptr; - } +void ID::ClearVal() { val = nullptr; } -void ID::SetVal(ValPtr v) - { - val = std::move(v); - Modified(); +void ID::SetVal(ValPtr v) { + val = std::move(v); + Modified(); #ifdef DEBUG - UpdateValID(); + UpdateValID(); #endif - if ( type && val && type->Tag() == TYPE_FUNC && - type->AsFuncType()->Flavor() == FUNC_FLAVOR_EVENT ) - { - EventHandler* handler = event_registry->Lookup(name); - auto func = val.get()->As()->AsFuncPtr(); - if ( ! handler ) - { - handler = new EventHandler(name); - handler->SetFunc(func); - event_registry->Register(handler, true); - - if ( ! IsExport() ) - register_new_event({NewRef{}, this}); - } - else - { - // Otherwise, internally defined events cannot - // have local handler. - handler->SetFunc(func); - } - } - } - -void ID::SetVal(ValPtr v, InitClass c) - { - if ( c == INIT_NONE || c == INIT_FULL ) - { - SetVal(std::move(v)); - return; - } - - if ( type->Tag() != TYPE_TABLE && (type->Tag() != TYPE_PATTERN || c == INIT_REMOVE) && - (type->Tag() != TYPE_VECTOR || c == INIT_REMOVE) ) - { - if ( c == INIT_EXTRA ) - Error("+= initializer only applies to tables, sets, vectors and patterns", v.get()); - else - Error("-= initializer only applies to tables and sets", v.get()); - } - - else - { - if ( c == INIT_EXTRA ) - { - if ( ! val ) - { - SetVal(std::move(v)); - return; - } - else - v->AddTo(val.get(), false); - } - else - { - if ( val ) - v->RemoveFrom(val.get()); - } - } - } - -void ID::SetVal(ExprPtr ev, InitClass c) - { - const auto& a = attrs->Find(c == INIT_EXTRA ? ATTR_ADD_FUNC : ATTR_DEL_FUNC); - - if ( ! a ) - Internal("no add/delete function in ID::SetVal"); - - if ( ! val ) - { - Error(zeek::util::fmt("%s initializer applied to ID without value", - c == INIT_EXTRA ? "+=" : "-="), - this); - return; - } - - EvalFunc(a->GetExpr(), std::move(ev)); - } - -bool ID::IsRedefinable() const - { - return GetAttr(ATTR_REDEF) != nullptr; - } - -void ID::SetAttrs(AttributesPtr a) - { - attrs = nullptr; - AddAttrs(std::move(a)); - } - -void ID::UpdateValAttrs() - { - if ( ! attrs ) - return; - - auto tag = GetType()->Tag(); - - if ( tag == TYPE_FUNC ) - { - const auto& attr = attrs->Find(ATTR_ERROR_HANDLER); - - if ( attr ) - event_registry->SetErrorHandler(Name()); - } - - if ( tag == TYPE_RECORD ) - { - const auto& attr = attrs->Find(ATTR_LOG); - - if ( attr ) - { - // Apply &log to all record fields. - RecordType* rt = GetType()->AsRecordType(); - for ( int i = 0; i < rt->NumFields(); ++i ) - { - TypeDecl* fd = rt->FieldDecl(i); - - if ( ! fd->attrs ) - fd->attrs = make_intrusive(rt->GetFieldType(i), true, IsGlobal()); - - fd->attrs->AddAttr(make_intrusive(ATTR_LOG)); - } - } - } - - if ( ! val ) - return; - - auto vtag = val->GetType()->Tag(); - - if ( vtag == TYPE_TABLE ) - val->AsTableVal()->SetAttrs(attrs); - - else if ( vtag == TYPE_FILE ) - val->AsFile()->SetAttrs(attrs.get()); - } - -const AttrPtr& ID::GetAttr(AttrTag t) const - { - return attrs ? attrs->Find(t) : Attr::nil; - } - -bool ID::IsDeprecated() const - { - return GetAttr(ATTR_DEPRECATED) != nullptr; - } - -void ID::MakeDeprecated(ExprPtr deprecation) - { - if ( IsDeprecated() ) - return; - - AddAttr(make_intrusive(ATTR_DEPRECATED, std::move(deprecation))); - } - -std::string ID::GetDeprecationWarning() const - { - std::string result; - const auto& depr_attr = GetAttr(ATTR_DEPRECATED); - - if ( depr_attr ) - result = depr_attr->DeprecationMessage(); - - if ( result.empty() ) - return util::fmt("deprecated (%s)", Name()); - else - return util::fmt("deprecated (%s): %s", Name(), result.c_str()); - } - -void ID::AddAttr(AttrPtr a, bool is_redef) - { - std::vector attrv{std::move(a)}; - auto attrs = make_intrusive(std::move(attrv), GetType(), false, IsGlobal()); - AddAttrs(std::move(attrs), is_redef); - } - -void ID::AddAttrs(AttributesPtr a, bool is_redef) - { - if ( attrs ) - attrs->AddAttrs(a, is_redef); - else - attrs = std::move(a); - - UpdateValAttrs(); - } - -void ID::RemoveAttr(AttrTag a) - { - if ( attrs ) - attrs->RemoveAttr(a); - } - -void ID::SetOption() - { - if ( is_option ) - return; - - is_option = true; - - // option implied redefinable - if ( ! IsRedefinable() ) - AddAttr(make_intrusive(ATTR_REDEF)); - } - -void ID::EvalFunc(ExprPtr ef, ExprPtr ev) - { - auto arg1 = make_intrusive(val); - auto args = make_intrusive(); - args->Append(std::move(arg1)); - args->Append(std::move(ev)); - auto ce = make_intrusive(std::move(ef), std::move(args)); - SetVal(ce->Eval(nullptr)); - } - -TraversalCode ID::Traverse(TraversalCallback* cb) const - { - TraversalCode tc = cb->PreID(this); - HANDLE_TC_STMT_PRE(tc); - - if ( is_type ) - { - tc = cb->PreTypedef(this); - HANDLE_TC_STMT_PRE(tc); - - tc = cb->PostTypedef(this); - HANDLE_TC_STMT_PRE(tc); - } - - // FIXME: Perhaps we should be checking at other than global scope. - else if ( val && IsFunc(val->GetType()->Tag()) && cb->current_scope == detail::global_scope() ) - { - tc = val->AsFunc()->Traverse(cb); - HANDLE_TC_STMT_PRE(tc); - } - - else if ( ! is_enum_const ) - { - tc = cb->PreDecl(this); - HANDLE_TC_STMT_PRE(tc); - - tc = cb->PostDecl(this); - HANDLE_TC_STMT_PRE(tc); - } - - tc = cb->PostID(this); - HANDLE_TC_EXPR_POST(tc); - } - -void ID::Error(const char* msg, const Obj* o2) - { - Obj::Error(msg, o2, true); - SetType(error_type()); - } - -void ID::Describe(ODesc* d) const - { - d->Add(name); - } - -void ID::DescribeExtended(ODesc* d) const - { - d->Add(name); - - if ( type ) - { - d->Add(" : "); - type->Describe(d); - } - - if ( val ) - { - d->Add(" = "); - val->Describe(d); - } - - if ( attrs ) - { - d->Add(" "); - attrs->Describe(d); - } - } - -void ID::DescribeReSTShort(ODesc* d) const - { - if ( is_type ) - d->Add(":zeek:type:`"); - else - d->Add(":zeek:id:`"); - - d->Add(name); - d->Add("`"); - - if ( type ) - { - d->Add(": "); - d->Add(":zeek:type:`"); - - if ( ! is_type && ! type->GetName().empty() ) - d->Add(type->GetName().c_str()); - else - { - TypeTag t = type->Tag(); - - switch ( t ) - { - case TYPE_TABLE: - d->Add(type->IsSet() ? "set" : type_name(t)); - break; - - case TYPE_FUNC: - d->Add(type->AsFuncType()->FlavorString().c_str()); - break; - - case TYPE_ENUM: - if ( is_type ) - d->Add(type_name(t)); - else - d->Add(zeekygen_mgr->GetEnumTypeName(Name()).c_str()); - break; - - default: - d->Add(type_name(t)); - break; - } - } - - d->Add("`"); - } - - if ( attrs ) - { - d->SP(); - attrs->DescribeReST(d, true); - } - } - -void ID::DescribeReST(ODesc* d, bool roles_only) const - { - if ( roles_only ) - { - if ( is_type ) - d->Add(":zeek:type:`"); - else - d->Add(":zeek:id:`"); - d->Add(name); - d->Add("`"); - } - else - { - if ( is_type ) - d->Add(".. zeek:type:: "); - else - d->Add(".. zeek:id:: "); - - d->Add(name); - - if ( auto sc = zeek::zeekygen::detail::source_code_range(this) ) - { - d->PushIndent(); - d->Add(util::fmt(":source-code: %s", sc->data())); - d->PopIndentNoNL(); - } - } - - d->PushIndent(); - d->NL(); - - if ( type ) - { - d->Add(":Type: "); - - if ( ! is_type && ! type->GetName().empty() ) - { - d->Add(":zeek:type:`"); - d->Add(type->GetName()); - d->Add("`"); - } - else - { - type->DescribeReST(d, roles_only); - - if ( IsFunc(type->Tag()) ) - { - auto ft = type->AsFuncType(); - - if ( ft->Flavor() == FUNC_FLAVOR_EVENT || ft->Flavor() == FUNC_FLAVOR_HOOK ) - { - const auto& protos = ft->Prototypes(); - - if ( protos.size() > 1 ) - { - auto first = true; - - for ( const auto& proto : protos ) - { - if ( first ) - { - first = false; - continue; - } - - d->NL(); - d->Add(":Type: :zeek:type:`"); - d->Add(ft->FlavorString()); - d->Add("` ("); - proto.args->DescribeFieldsReST(d, true); - d->Add(")"); - } - } - } - } - } - - d->NL(); - } - - if ( attrs ) - { - d->Add(":Attributes: "); - attrs->DescribeReST(d); - d->NL(); - } - - if ( val && type && type->Tag() != TYPE_FUNC && type->InternalType() != TYPE_INTERNAL_VOID && - // Values within Version module are likely to include a - // constantly-changing version number and be a frequent - // source of error/desynchronization, so don't include them. - ModuleName() != "Version" ) - { - d->Add(":Default:"); - auto ii = zeekygen_mgr->GetIdentifierInfo(Name()); - auto redefs = ii->GetRedefs(); - const auto& iv = ! redefs.empty() && ii->InitialVal() ? ii->InitialVal() : val; - - if ( type->InternalType() == TYPE_INTERNAL_OTHER ) - { - switch ( type->Tag() ) - { - case TYPE_TABLE: - if ( iv->AsTable()->Length() == 0 ) - { - d->Add(" ``{}``"); - d->NL(); - break; - } - // Fall-through. - - default: - d->NL(); - d->PushIndent(); - d->Add("::"); - d->NL(); - d->PushIndent(); - iv->DescribeReST(d); - d->PopIndent(); - d->PopIndent(); - } - } - - else - { - d->SP(); - iv->DescribeReST(d); - d->NL(); - } - - for ( auto& ir : redefs ) - { - if ( ! ir->init_expr ) - continue; - - if ( ir->ic == INIT_NONE ) - continue; - - std::string redef_str; - ODesc expr_desc; - ir->init_expr->Describe(&expr_desc); - redef_str = expr_desc.Description(); - redef_str = util::strreplace(redef_str, "\n", " "); - - d->Add(":Redefinition: "); - d->Add(util::fmt("from :doc:`/scripts/%s`", ir->from_script.data())); - d->NL(); - d->PushIndent(); - - if ( ir->ic == INIT_FULL ) - d->Add("``=``"); - else if ( ir->ic == INIT_EXTRA ) - d->Add("``+=``"); - else if ( ir->ic == INIT_REMOVE ) - d->Add("``-=``"); - else - assert(false); - - d->Add("::"); - d->NL(); - d->PushIndent(); - d->Add(redef_str.data()); - d->PopIndent(); - d->PopIndent(); - } - } - } + if ( type && val && type->Tag() == TYPE_FUNC && type->AsFuncType()->Flavor() == FUNC_FLAVOR_EVENT ) { + EventHandler* handler = event_registry->Lookup(name); + auto func = val.get()->As()->AsFuncPtr(); + if ( ! handler ) { + handler = new EventHandler(name); + handler->SetFunc(func); + event_registry->Register(handler, true); + + if ( ! IsExport() ) + register_new_event({NewRef{}, this}); + } + else { + // Otherwise, internally defined events cannot + // have local handler. + handler->SetFunc(func); + } + } +} + +void ID::SetVal(ValPtr v, InitClass c) { + if ( c == INIT_NONE || c == INIT_FULL ) { + SetVal(std::move(v)); + return; + } + + if ( type->Tag() != TYPE_TABLE && (type->Tag() != TYPE_PATTERN || c == INIT_REMOVE) && + (type->Tag() != TYPE_VECTOR || c == INIT_REMOVE) ) { + if ( c == INIT_EXTRA ) + Error("+= initializer only applies to tables, sets, vectors and patterns", v.get()); + else + Error("-= initializer only applies to tables and sets", v.get()); + } + + else { + if ( c == INIT_EXTRA ) { + if ( ! val ) { + SetVal(std::move(v)); + return; + } + else + v->AddTo(val.get(), false); + } + else { + if ( val ) + v->RemoveFrom(val.get()); + } + } +} + +void ID::SetVal(ExprPtr ev, InitClass c) { + const auto& a = attrs->Find(c == INIT_EXTRA ? ATTR_ADD_FUNC : ATTR_DEL_FUNC); + + if ( ! a ) + Internal("no add/delete function in ID::SetVal"); + + if ( ! val ) { + Error(zeek::util::fmt("%s initializer applied to ID without value", c == INIT_EXTRA ? "+=" : "-="), this); + return; + } + + EvalFunc(a->GetExpr(), std::move(ev)); +} + +bool ID::IsRedefinable() const { return GetAttr(ATTR_REDEF) != nullptr; } + +void ID::SetAttrs(AttributesPtr a) { + attrs = nullptr; + AddAttrs(std::move(a)); +} + +void ID::UpdateValAttrs() { + if ( ! attrs ) + return; + + auto tag = GetType()->Tag(); + + if ( tag == TYPE_FUNC ) { + const auto& attr = attrs->Find(ATTR_ERROR_HANDLER); + + if ( attr ) + event_registry->SetErrorHandler(Name()); + } + + if ( tag == TYPE_RECORD ) { + const auto& attr = attrs->Find(ATTR_LOG); + + if ( attr ) { + // Apply &log to all record fields. + RecordType* rt = GetType()->AsRecordType(); + for ( int i = 0; i < rt->NumFields(); ++i ) { + TypeDecl* fd = rt->FieldDecl(i); + + if ( ! fd->attrs ) + fd->attrs = make_intrusive(rt->GetFieldType(i), true, IsGlobal()); + + fd->attrs->AddAttr(make_intrusive(ATTR_LOG)); + } + } + } + + if ( ! val ) + return; + + auto vtag = val->GetType()->Tag(); + + if ( vtag == TYPE_TABLE ) + val->AsTableVal()->SetAttrs(attrs); + + else if ( vtag == TYPE_FILE ) + val->AsFile()->SetAttrs(attrs.get()); +} + +const AttrPtr& ID::GetAttr(AttrTag t) const { return attrs ? attrs->Find(t) : Attr::nil; } + +bool ID::IsDeprecated() const { return GetAttr(ATTR_DEPRECATED) != nullptr; } + +void ID::MakeDeprecated(ExprPtr deprecation) { + if ( IsDeprecated() ) + return; + + AddAttr(make_intrusive(ATTR_DEPRECATED, std::move(deprecation))); +} + +std::string ID::GetDeprecationWarning() const { + std::string result; + const auto& depr_attr = GetAttr(ATTR_DEPRECATED); + + if ( depr_attr ) + result = depr_attr->DeprecationMessage(); + + if ( result.empty() ) + return util::fmt("deprecated (%s)", Name()); + else + return util::fmt("deprecated (%s): %s", Name(), result.c_str()); +} + +void ID::AddAttr(AttrPtr a, bool is_redef) { + std::vector attrv{std::move(a)}; + auto attrs = make_intrusive(std::move(attrv), GetType(), false, IsGlobal()); + AddAttrs(std::move(attrs), is_redef); +} + +void ID::AddAttrs(AttributesPtr a, bool is_redef) { + if ( attrs ) + attrs->AddAttrs(a, is_redef); + else + attrs = std::move(a); + + UpdateValAttrs(); +} + +void ID::RemoveAttr(AttrTag a) { + if ( attrs ) + attrs->RemoveAttr(a); +} + +void ID::SetOption() { + if ( is_option ) + return; + + is_option = true; + + // option implied redefinable + if ( ! IsRedefinable() ) + AddAttr(make_intrusive(ATTR_REDEF)); +} + +void ID::EvalFunc(ExprPtr ef, ExprPtr ev) { + auto arg1 = make_intrusive(val); + auto args = make_intrusive(); + args->Append(std::move(arg1)); + args->Append(std::move(ev)); + auto ce = make_intrusive(std::move(ef), std::move(args)); + SetVal(ce->Eval(nullptr)); +} + +TraversalCode ID::Traverse(TraversalCallback* cb) const { + TraversalCode tc = cb->PreID(this); + HANDLE_TC_STMT_PRE(tc); + + if ( is_type ) { + tc = cb->PreTypedef(this); + HANDLE_TC_STMT_PRE(tc); + + tc = cb->PostTypedef(this); + HANDLE_TC_STMT_PRE(tc); + } + + // FIXME: Perhaps we should be checking at other than global scope. + else if ( val && IsFunc(val->GetType()->Tag()) && cb->current_scope == detail::global_scope() ) { + tc = val->AsFunc()->Traverse(cb); + HANDLE_TC_STMT_PRE(tc); + } + + else if ( ! is_enum_const ) { + tc = cb->PreDecl(this); + HANDLE_TC_STMT_PRE(tc); + + tc = cb->PostDecl(this); + HANDLE_TC_STMT_PRE(tc); + } + + tc = cb->PostID(this); + HANDLE_TC_EXPR_POST(tc); +} + +void ID::Error(const char* msg, const Obj* o2) { + Obj::Error(msg, o2, true); + SetType(error_type()); +} + +void ID::Describe(ODesc* d) const { d->Add(name); } + +void ID::DescribeExtended(ODesc* d) const { + d->Add(name); + + if ( type ) { + d->Add(" : "); + type->Describe(d); + } + + if ( val ) { + d->Add(" = "); + val->Describe(d); + } + + if ( attrs ) { + d->Add(" "); + attrs->Describe(d); + } +} + +void ID::DescribeReSTShort(ODesc* d) const { + if ( is_type ) + d->Add(":zeek:type:`"); + else + d->Add(":zeek:id:`"); + + d->Add(name); + d->Add("`"); + + if ( type ) { + d->Add(": "); + d->Add(":zeek:type:`"); + + if ( ! is_type && ! type->GetName().empty() ) + d->Add(type->GetName().c_str()); + else { + TypeTag t = type->Tag(); + + switch ( t ) { + case TYPE_TABLE: d->Add(type->IsSet() ? "set" : type_name(t)); break; + + case TYPE_FUNC: d->Add(type->AsFuncType()->FlavorString().c_str()); break; + + case TYPE_ENUM: + if ( is_type ) + d->Add(type_name(t)); + else + d->Add(zeekygen_mgr->GetEnumTypeName(Name()).c_str()); + break; + + default: d->Add(type_name(t)); break; + } + } + + d->Add("`"); + } + + if ( attrs ) { + d->SP(); + attrs->DescribeReST(d, true); + } +} + +void ID::DescribeReST(ODesc* d, bool roles_only) const { + if ( roles_only ) { + if ( is_type ) + d->Add(":zeek:type:`"); + else + d->Add(":zeek:id:`"); + d->Add(name); + d->Add("`"); + } + else { + if ( is_type ) + d->Add(".. zeek:type:: "); + else + d->Add(".. zeek:id:: "); + + d->Add(name); + + if ( auto sc = zeek::zeekygen::detail::source_code_range(this) ) { + d->PushIndent(); + d->Add(util::fmt(":source-code: %s", sc->data())); + d->PopIndentNoNL(); + } + } + + d->PushIndent(); + d->NL(); + + if ( type ) { + d->Add(":Type: "); + + if ( ! is_type && ! type->GetName().empty() ) { + d->Add(":zeek:type:`"); + d->Add(type->GetName()); + d->Add("`"); + } + else { + type->DescribeReST(d, roles_only); + + if ( IsFunc(type->Tag()) ) { + auto ft = type->AsFuncType(); + + if ( ft->Flavor() == FUNC_FLAVOR_EVENT || ft->Flavor() == FUNC_FLAVOR_HOOK ) { + const auto& protos = ft->Prototypes(); + + if ( protos.size() > 1 ) { + auto first = true; + + for ( const auto& proto : protos ) { + if ( first ) { + first = false; + continue; + } + + d->NL(); + d->Add(":Type: :zeek:type:`"); + d->Add(ft->FlavorString()); + d->Add("` ("); + proto.args->DescribeFieldsReST(d, true); + d->Add(")"); + } + } + } + } + } + + d->NL(); + } + + if ( attrs ) { + d->Add(":Attributes: "); + attrs->DescribeReST(d); + d->NL(); + } + + if ( val && type && type->Tag() != TYPE_FUNC && type->InternalType() != TYPE_INTERNAL_VOID && + // Values within Version module are likely to include a + // constantly-changing version number and be a frequent + // source of error/desynchronization, so don't include them. + ModuleName() != "Version" ) { + d->Add(":Default:"); + auto ii = zeekygen_mgr->GetIdentifierInfo(Name()); + auto redefs = ii->GetRedefs(); + const auto& iv = ! redefs.empty() && ii->InitialVal() ? ii->InitialVal() : val; + + if ( type->InternalType() == TYPE_INTERNAL_OTHER ) { + switch ( type->Tag() ) { + case TYPE_TABLE: + if ( iv->AsTable()->Length() == 0 ) { + d->Add(" ``{}``"); + d->NL(); + break; + } + // Fall-through. + + default: + d->NL(); + d->PushIndent(); + d->Add("::"); + d->NL(); + d->PushIndent(); + iv->DescribeReST(d); + d->PopIndent(); + d->PopIndent(); + } + } + + else { + d->SP(); + iv->DescribeReST(d); + d->NL(); + } + + for ( auto& ir : redefs ) { + if ( ! ir->init_expr ) + continue; + + if ( ir->ic == INIT_NONE ) + continue; + + std::string redef_str; + ODesc expr_desc; + ir->init_expr->Describe(&expr_desc); + redef_str = expr_desc.Description(); + redef_str = util::strreplace(redef_str, "\n", " "); + + d->Add(":Redefinition: "); + d->Add(util::fmt("from :doc:`/scripts/%s`", ir->from_script.data())); + d->NL(); + d->PushIndent(); + + if ( ir->ic == INIT_FULL ) + d->Add("``=``"); + else if ( ir->ic == INIT_EXTRA ) + d->Add("``+=``"); + else if ( ir->ic == INIT_REMOVE ) + d->Add("``-=``"); + else + assert(false); + + d->Add("::"); + d->NL(); + d->PushIndent(); + d->Add(redef_str.data()); + d->PopIndent(); + d->PopIndent(); + } + } +} #ifdef DEBUG -void ID::UpdateValID() - { - if ( IsGlobal() && val && name && name[0] != '#' ) - val->SetID(this); - } +void ID::UpdateValID() { + if ( IsGlobal() && val && name && name[0] != '#' ) + val->SetID(this); +} #endif -void ID::AddOptionHandler(FuncPtr callback, int priority) - { - option_handlers.emplace(priority, std::move(callback)); - } +void ID::AddOptionHandler(FuncPtr callback, int priority) { option_handlers.emplace(priority, std::move(callback)); } -std::vector ID::GetOptionHandlers() const - { - // multimap is sorted - // It might be worth caching this if we expect it to be called - // a lot... - std::vector v; - for ( auto& element : option_handlers ) - v.push_back(element.second.get()); - return v; - } +std::vector ID::GetOptionHandlers() const { + // multimap is sorted + // It might be worth caching this if we expect it to be called + // a lot... + std::vector v; + for ( auto& element : option_handlers ) + v.push_back(element.second.get()); + return v; +} -void ID::ClearOptInfo() - { - delete opt_info; - opt_info = nullptr; - } +void ID::ClearOptInfo() { + delete opt_info; + opt_info = nullptr; +} - } // namespace detail +} // namespace detail - } // namespace zeek +} // namespace zeek diff --git a/src/ID.h b/src/ID.h index d3c231e6a4..53b4b44fce 100644 --- a/src/ID.h +++ b/src/ID.h @@ -13,8 +13,7 @@ #include "zeek/Obj.h" #include "zeek/TraverseTypes.h" -namespace zeek - { +namespace zeek { class Func; class Val; @@ -31,29 +30,22 @@ using EnumTypePtr = IntrusivePtr; using ValPtr = IntrusivePtr; using FuncPtr = IntrusivePtr; - } +} // namespace zeek -namespace zeek::detail - { +namespace zeek::detail { class Attributes; class Expr; using ExprPtr = IntrusivePtr; -enum InitClass - { - INIT_NONE, - INIT_FULL, - INIT_EXTRA, - INIT_REMOVE, - INIT_SKIP, - }; -enum IDScope - { - SCOPE_FUNCTION, - SCOPE_MODULE, - SCOPE_GLOBAL - }; +enum InitClass { + INIT_NONE, + INIT_FULL, + INIT_EXTRA, + INIT_REMOVE, + INIT_SKIP, +}; +enum IDScope { SCOPE_FUNCTION, SCOPE_MODULE, SCOPE_GLOBAL }; class ID; using IDPtr = IntrusivePtr; @@ -61,130 +53,131 @@ using IDSet = std::unordered_set; class IDOptInfo; -class ID final : public Obj, public notifier::detail::Modifiable - { +class ID final : public Obj, public notifier::detail::Modifiable { public: - static inline const IDPtr nil; + static inline const IDPtr nil; - ID(const char* name, IDScope arg_scope, bool arg_is_export); + ID(const char* name, IDScope arg_scope, bool arg_is_export); - ~ID() override; + ~ID() override; - const char* Name() const { return name; } + const char* Name() const { return name; } - int Scope() const { return scope; } - bool IsGlobal() const { return scope != SCOPE_FUNCTION; } + int Scope() const { return scope; } + bool IsGlobal() const { return scope != SCOPE_FUNCTION; } - bool IsExport() const { return is_export; } - void SetExport() { is_export = true; } + bool IsExport() const { return is_export; } + void SetExport() { is_export = true; } - std::string ModuleName() const; + std::string ModuleName() const; - void SetType(TypePtr t); + void SetType(TypePtr t); - const TypePtr& GetType() const { return type; } + const TypePtr& GetType() const { return type; } - template IntrusivePtr GetType() const { return cast_intrusive(type); } + template + IntrusivePtr GetType() const { + return cast_intrusive(type); + } - bool IsType() const { return is_type; } + bool IsType() const { return is_type; } - void MakeType() { is_type = true; } + void MakeType() { is_type = true; } - void SetVal(ValPtr v); + void SetVal(ValPtr v); - void SetVal(ValPtr v, InitClass c); - void SetVal(ExprPtr ev, InitClass c); + void SetVal(ValPtr v, InitClass c); + void SetVal(ExprPtr ev, InitClass c); - bool HasVal() const { return val != nullptr; } + bool HasVal() const { return val != nullptr; } - const ValPtr& GetVal() const { return val; } + const ValPtr& GetVal() const { return val; } - void ClearVal(); + void ClearVal(); - void SetConst() { is_const = true; } - bool IsConst() const { return is_const; } + void SetConst() { is_const = true; } + bool IsConst() const { return is_const; } - void SetOption(); - bool IsOption() const { return is_option; } - bool IsBlank() const { return is_blank; }; + void SetOption(); + bool IsOption() const { return is_option; } + bool IsBlank() const { return is_blank; }; - void SetEnumConst() { is_enum_const = true; } - bool IsEnumConst() const { return is_enum_const; } + void SetEnumConst() { is_enum_const = true; } + bool IsEnumConst() const { return is_enum_const; } - void SetOffset(int arg_offset) { offset = arg_offset; } - int Offset() const { return offset; } + void SetOffset(int arg_offset) { offset = arg_offset; } + int Offset() const { return offset; } - bool IsRedefinable() const; + bool IsRedefinable() const; - void SetAttrs(AttributesPtr attr); - void AddAttr(AttrPtr a, bool is_redef = false); - void AddAttrs(AttributesPtr attr, bool is_redef = false); - void RemoveAttr(AttrTag a); - void UpdateValAttrs(); + void SetAttrs(AttributesPtr attr); + void AddAttr(AttrPtr a, bool is_redef = false); + void AddAttrs(AttributesPtr attr, bool is_redef = false); + void RemoveAttr(AttrTag a); + void UpdateValAttrs(); - const AttributesPtr& GetAttrs() const { return attrs; } + const AttributesPtr& GetAttrs() const { return attrs; } - const AttrPtr& GetAttr(AttrTag t) const; + const AttrPtr& GetAttr(AttrTag t) const; - bool IsDeprecated() const; + bool IsDeprecated() const; - void MakeDeprecated(ExprPtr deprecation); + void MakeDeprecated(ExprPtr deprecation); - std::string GetDeprecationWarning() const; + std::string GetDeprecationWarning() const; - void Error(const char* msg, const Obj* o2 = nullptr); + void Error(const char* msg, const Obj* o2 = nullptr); - void Describe(ODesc* d) const override; - // Adds type and value to description. - void DescribeExtended(ODesc* d) const; - // Produces a description that's reST-ready. - void DescribeReST(ODesc* d, bool roles_only = false) const; - void DescribeReSTShort(ODesc* d) const; + void Describe(ODesc* d) const override; + // Adds type and value to description. + void DescribeExtended(ODesc* d) const; + // Produces a description that's reST-ready. + void DescribeReST(ODesc* d, bool roles_only = false) const; + void DescribeReSTShort(ODesc* d) const; - bool DoInferReturnType() const { return infer_return_type; } - void SetInferReturnType(bool infer) { infer_return_type = infer; } + bool DoInferReturnType() const { return infer_return_type; } + void SetInferReturnType(bool infer) { infer_return_type = infer; } - TraversalCode Traverse(TraversalCallback* cb) const; + TraversalCode Traverse(TraversalCallback* cb) const; - bool HasOptionHandlers() const { return ! option_handlers.empty(); } + bool HasOptionHandlers() const { return ! option_handlers.empty(); } - void AddOptionHandler(FuncPtr callback, int priority); - std::vector GetOptionHandlers() const; + void AddOptionHandler(FuncPtr callback, int priority); + std::vector GetOptionHandlers() const; - IDOptInfo* GetOptInfo() const { return opt_info; } - void ClearOptInfo(); + IDOptInfo* GetOptInfo() const { return opt_info; } + void ClearOptInfo(); protected: - void EvalFunc(ExprPtr ef, ExprPtr ev); + void EvalFunc(ExprPtr ef, ExprPtr ev); #ifdef DEBUG - void UpdateValID(); + void UpdateValID(); #endif - const char* name; - IDScope scope; - bool is_export; - bool infer_return_type; - TypePtr type; - bool is_const, is_enum_const, is_type, is_option, is_blank; - int offset; - ValPtr val; - AttributesPtr attrs; + const char* name; + IDScope scope; + bool is_export; + bool infer_return_type; + TypePtr type; + bool is_const, is_enum_const, is_type, is_option, is_blank; + int offset; + ValPtr val; + AttributesPtr attrs; - // contains list of functions that are called when an option changes - std::multimap option_handlers; + // contains list of functions that are called when an option changes + std::multimap option_handlers; - // Information managed by script optimization. We package this - // up into a separate object for purposes of modularity, and, - // via the associated pointer, to allow it to be modified in - // contexts where the ID is itself "const". - IDOptInfo* opt_info; - }; + // Information managed by script optimization. We package this + // up into a separate object for purposes of modularity, and, + // via the associated pointer, to allow it to be modified in + // contexts where the ID is itself "const". + IDOptInfo* opt_info; +}; - } // namespace zeek::detail +} // namespace zeek::detail -namespace zeek::id - { +namespace zeek::id { /** * Lookup an ID in the global module and return it, if one exists; @@ -208,10 +201,10 @@ const TypePtr& find_type(std::string_view name); * @param name The identifier name to lookup * @return The type of the identifier. */ -template IntrusivePtr find_type(std::string_view name) - { - return cast_intrusive(find_type(name)); - } +template +IntrusivePtr find_type(std::string_view name) { + return cast_intrusive(find_type(name)); +} /** * Lookup an ID by its name and return its value. A fatal occurs if the ID @@ -227,10 +220,10 @@ const ValPtr& find_val(std::string_view name); * @param name The identifier name to lookup * @return The current value of the identifier. */ -template IntrusivePtr find_val(std::string_view name) - { - return cast_intrusive(find_val(name)); - } +template +IntrusivePtr find_val(std::string_view name) { + return cast_intrusive(find_val(name)); +} /** * Lookup an ID by its name and return its value. A fatal occurs if the ID @@ -246,10 +239,10 @@ const ValPtr& find_const(std::string_view name); * @param name The identifier name to lookup * @return The current value of the identifier. */ -template IntrusivePtr find_const(std::string_view name) - { - return cast_intrusive(find_const(name)); - } +template +IntrusivePtr find_const(std::string_view name) { + return cast_intrusive(find_const(name)); +} /** * Lookup an ID by its name and return the function it references. @@ -271,10 +264,9 @@ extern TableTypePtr count_set; extern VectorTypePtr string_vec; extern VectorTypePtr index_vec; -namespace detail - { +namespace detail { void init_types(); - } // namespace detail - } // namespace zeek::id +} // namespace detail +} // namespace zeek::id diff --git a/src/IP.cc b/src/IP.cc index e3ac2e480f..7fe2075cc5 100644 --- a/src/IP.cc +++ b/src/IP.cc @@ -13,833 +13,702 @@ #include "zeek/Var.h" #include "zeek/ZeekString.h" -namespace zeek - { - -bool IPv6_Hdr::IsOptionTruncated(uint16_t off) const - { - if ( Length() < off ) - { - reporter->Weird("truncated_IPv6_option"); - return true; - } - - return false; - } - -static VectorValPtr BuildOptionsVal(const u_char* data, int len) - { - auto vv = make_intrusive(id::find_type("ip6_options")); - - while ( len > 0 && static_cast(len) >= sizeof(struct ip6_opt) ) - { - static auto ip6_option_type = id::find_type("ip6_option"); - const struct ip6_opt* opt = (const struct ip6_opt*)data; - auto rv = make_intrusive(ip6_option_type); - rv->Assign(0, opt->ip6o_type); - - if ( opt->ip6o_type == 0 ) - { - // Pad1 option - rv->Assign(1, 0); - rv->Assign(2, val_mgr->EmptyString()); - data += sizeof(uint8_t); - len -= sizeof(uint8_t); - } - else - { - // PadN or other option - uint16_t off = 2 * sizeof(uint8_t); - - if ( len < opt->ip6o_len + off ) - break; - - rv->Assign(1, opt->ip6o_len); - rv->Assign(2, new String(data + off, opt->ip6o_len, true)); - data += opt->ip6o_len + off; - len -= opt->ip6o_len + off; - } - - vv->Assign(vv->Size(), std::move(rv)); - } - - return vv; - } - -RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const - { - RecordValPtr rv; - - switch ( type ) - { - case IPPROTO_IPV6: - { - static auto ip6_hdr_type = id::find_type("ip6_hdr"); - rv = make_intrusive(ip6_hdr_type); - const struct ip6_hdr* ip6 = (const struct ip6_hdr*)data; - rv->Assign(0, static_cast(ntohl(ip6->ip6_flow) & 0x0ff00000) >> 20); - rv->Assign(1, static_cast(ntohl(ip6->ip6_flow) & 0x000fffff)); - rv->Assign(2, ntohs(ip6->ip6_plen)); - rv->Assign(3, ip6->ip6_nxt); - rv->Assign(4, ip6->ip6_hlim); - rv->Assign(5, make_intrusive(IPAddr(ip6->ip6_src))); - rv->Assign(6, make_intrusive(IPAddr(ip6->ip6_dst))); - if ( ! chain ) - chain = make_intrusive(id::find_type("ip6_ext_hdr_chain")); - rv->Assign(7, std::move(chain)); - } - break; - - case IPPROTO_HOPOPTS: - { - uint16_t off = 2 * sizeof(uint8_t); - if ( IsOptionTruncated(off) ) - return nullptr; - - static auto ip6_hopopts_type = id::find_type("ip6_hopopts"); - rv = make_intrusive(ip6_hopopts_type); - const struct ip6_hbh* hbh = (const struct ip6_hbh*)data; - rv->Assign(0, hbh->ip6h_nxt); - rv->Assign(1, hbh->ip6h_len); - rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); - } - break; - - case IPPROTO_DSTOPTS: - { - uint16_t off = 2 * sizeof(uint8_t); - if ( IsOptionTruncated(off) ) - return nullptr; - - static auto ip6_dstopts_type = id::find_type("ip6_dstopts"); - rv = make_intrusive(ip6_dstopts_type); - const struct ip6_dest* dst = (const struct ip6_dest*)data; - rv->Assign(0, dst->ip6d_nxt); - rv->Assign(1, dst->ip6d_len); - rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); - } - break; - - case IPPROTO_ROUTING: - { - uint16_t off = 4 * sizeof(uint8_t); - if ( IsOptionTruncated(off) ) - return nullptr; - - static auto ip6_routing_type = id::find_type("ip6_routing"); - rv = make_intrusive(ip6_routing_type); - const struct ip6_rthdr* rt = (const struct ip6_rthdr*)data; - rv->Assign(0, rt->ip6r_nxt); - rv->Assign(1, rt->ip6r_len); - rv->Assign(2, rt->ip6r_type); - rv->Assign(3, rt->ip6r_segleft); - rv->Assign(4, new String(data + off, Length() - off, true)); - } - break; - - case IPPROTO_FRAGMENT: - { - static auto ip6_fragment_type = id::find_type("ip6_fragment"); - rv = make_intrusive(ip6_fragment_type); - const struct ip6_frag* frag = (const struct ip6_frag*)data; - rv->Assign(0, frag->ip6f_nxt); - rv->Assign(1, frag->ip6f_reserved); - rv->Assign(2, (ntohs(frag->ip6f_offlg) & 0xfff8) >> 3); - rv->Assign(3, (ntohs(frag->ip6f_offlg) & 0x0006) >> 1); - rv->Assign(4, static_cast(ntohs(frag->ip6f_offlg) & 0x0001)); - rv->Assign(5, static_cast(ntohl(frag->ip6f_ident))); - } - break; - - case IPPROTO_AH: - { - static auto ip6_ah_type = id::find_type("ip6_ah"); - rv = make_intrusive(ip6_ah_type); - rv->Assign(0, ((ip6_ext*)data)->ip6e_nxt); - rv->Assign(1, ((ip6_ext*)data)->ip6e_len); - rv->Assign(2, ntohs(((uint16_t*)data)[1])); - rv->Assign(3, static_cast(ntohl(((uint32_t*)data)[1]))); - - if ( Length() >= 12 ) - { - // Sequence Number and ICV fields can only be extracted if - // Payload Len was non-zero for this header. - rv->Assign(4, static_cast(ntohl(((uint32_t*)data)[2]))); - uint16_t off = 3 * sizeof(uint32_t); - rv->Assign(5, new String(data + off, Length() - off, true)); - } - } - break; - - case IPPROTO_ESP: - { - static auto ip6_esp_type = id::find_type("ip6_esp"); - rv = make_intrusive(ip6_esp_type); - const uint32_t* esp = (const uint32_t*)data; - rv->Assign(0, static_cast(ntohl(esp[0]))); - rv->Assign(1, static_cast(ntohl(esp[1]))); - } - break; - - case IPPROTO_MOBILITY: - { - static auto ip6_mob_type = id::find_type("ip6_mobility_hdr"); - rv = make_intrusive(ip6_mob_type); - const struct ip6_mobility* mob = (const struct ip6_mobility*)data; - rv->Assign(0, mob->ip6mob_payload); - rv->Assign(1, mob->ip6mob_len); - rv->Assign(2, mob->ip6mob_type); - rv->Assign(3, mob->ip6mob_rsv); - rv->Assign(4, ntohs(mob->ip6mob_chksum)); - - static auto ip6_mob_msg_type = id::find_type("ip6_mobility_msg"); - auto msg = make_intrusive(ip6_mob_msg_type); - msg->Assign(0, mob->ip6mob_type); - - uint16_t off = sizeof(ip6_mobility); - const u_char* msg_data = data + off; - - static auto ip6_mob_brr_type = id::find_type("ip6_mobility_brr"); - static auto ip6_mob_hoti_type = id::find_type("ip6_mobility_hoti"); - static auto ip6_mob_coti_type = id::find_type("ip6_mobility_coti"); - static auto ip6_mob_hot_type = id::find_type("ip6_mobility_hot"); - static auto ip6_mob_cot_type = id::find_type("ip6_mobility_cot"); - static auto ip6_mob_bu_type = id::find_type("ip6_mobility_bu"); - static auto ip6_mob_back_type = id::find_type("ip6_mobility_back"); - static auto ip6_mob_be_type = id::find_type("ip6_mobility_be"); - - switch ( mob->ip6mob_type ) - { - case 0: - { - off += sizeof(uint16_t); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_brr_type); - m->Assign(0, ntohs(*((uint16_t*)msg_data))); - m->Assign(1, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(1, std::move(m)); - break; - } - - case 1: - { - off += sizeof(uint16_t) + sizeof(uint64_t); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_hoti_type); - m->Assign(0, ntohs(*((uint16_t*)msg_data))); - m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); - m->Assign(2, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(2, std::move(m)); - break; - } - - case 2: - { - off += sizeof(uint16_t) + sizeof(uint64_t); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_coti_type); - m->Assign(0, ntohs(*((uint16_t*)msg_data))); - m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); - m->Assign(2, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(3, std::move(m)); - break; - } - - case 3: - { - off += sizeof(uint16_t) + 2 * sizeof(uint64_t); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_hot_type); - m->Assign(0, ntohs(*((uint16_t*)msg_data))); - m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); - m->Assign( - 2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t))))); - m->Assign(3, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(4, std::move(m)); - break; - } - - case 4: - { - off += sizeof(uint16_t) + 2 * sizeof(uint64_t); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_cot_type); - m->Assign(0, ntohs(*((uint16_t*)msg_data))); - m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); - m->Assign( - 2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t))))); - m->Assign(3, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(5, std::move(m)); - break; - } - - case 5: - { - off += 3 * sizeof(uint16_t); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_bu_type); - m->Assign(0, ntohs(*((uint16_t*)msg_data))); - m->Assign(1, static_cast( - ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x8000)); - m->Assign(2, static_cast( - ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x4000)); - m->Assign(3, static_cast( - ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x2000)); - m->Assign(4, static_cast( - ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x1000)); - m->Assign(5, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); - m->Assign(6, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(6, std::move(m)); - break; - } - - case 6: - { - off += 3 * sizeof(uint16_t); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_back_type); - m->Assign(0, *((uint8_t*)msg_data)); - m->Assign(1, - static_cast(*((uint8_t*)(msg_data + sizeof(uint8_t))) & 0x80)); - m->Assign(2, ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t))))); - m->Assign(3, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); - m->Assign(4, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(7, std::move(m)); - break; - } - - case 7: - { - off += sizeof(uint16_t) + sizeof(in6_addr); - if ( IsOptionTruncated(off) ) - break; - - auto m = make_intrusive(ip6_mob_be_type); - m->Assign(0, *((uint8_t*)msg_data)); - const in6_addr* hoa = (const in6_addr*)(msg_data + sizeof(uint16_t)); - m->Assign(1, make_intrusive(IPAddr(*hoa))); - m->Assign(2, BuildOptionsVal(data + off, Length() - off)); - msg->Assign(8, std::move(m)); - break; - } - - default: - reporter->Weird("unknown_mobility_type", util::fmt("%d", mob->ip6mob_type)); - break; - } - - rv->Assign(5, std::move(msg)); - } - break; - - default: - break; - } - - return rv; - } - -RecordValPtr IPv6_Hdr::ToVal() const - { - return ToVal(nullptr); - } - -IPAddr IP_Hdr::IPHeaderSrcAddr() const - { - return ip4 ? IPAddr(ip4->ip_src) : IPAddr(ip6->ip6_src); - } - -IPAddr IP_Hdr::IPHeaderDstAddr() const - { - return ip4 ? IPAddr(ip4->ip_dst) : IPAddr(ip6->ip6_dst); - } - -IPAddr IP_Hdr::SrcAddr() const - { - return ip4 ? IPAddr(ip4->ip_src) : ip6_hdrs->SrcAddr(); - } - -IPAddr IP_Hdr::DstAddr() const - { - return ip4 ? IPAddr(ip4->ip_dst) : ip6_hdrs->DstAddr(); - } - -RecordValPtr IP_Hdr::ToIPHdrVal() const - { - RecordValPtr rval; - - if ( ip4 ) - { - static auto ip4_hdr_type = id::find_type("ip4_hdr"); - rval = make_intrusive(ip4_hdr_type); - rval->Assign(0, ip4->ip_hl * 4); - rval->Assign(1, ip4->ip_tos); - rval->Assign(2, ntohs(ip4->ip_len)); - rval->Assign(3, ntohs(ip4->ip_id)); - rval->Assign(4, DF()); - rval->Assign(5, MF()); - rval->Assign(6, FragOffset()); // 13 bit offset as multiple of 8 - rval->Assign(7, ip4->ip_ttl); - rval->Assign(8, ip4->ip_p); - rval->Assign(9, ntohs(ip4->ip_sum)); - rval->Assign(10, make_intrusive(ip4->ip_src.s_addr)); - rval->Assign(11, make_intrusive(ip4->ip_dst.s_addr)); - } - else - { - rval = ((*ip6_hdrs)[0])->ToVal(ip6_hdrs->ToVal()); - } - - return rval; - } - -RecordValPtr IP_Hdr::ToPktHdrVal() const - { - static auto pkt_hdr_type = id::find_type("pkt_hdr"); - return ToPktHdrVal(make_intrusive(pkt_hdr_type), 0); - } - -RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const - { - static auto tcp_hdr_type = id::find_type("tcp_hdr"); - static auto udp_hdr_type = id::find_type("udp_hdr"); - static auto icmp_hdr_type = id::find_type("icmp_hdr"); - - if ( ip4 ) - pkt_hdr->Assign(sindex + 0, ToIPHdrVal()); - else - pkt_hdr->Assign(sindex + 1, ToIPHdrVal()); - - // L4 header. - const u_char* data = Payload(); - - int proto = NextProto(); - switch ( proto ) - { - case IPPROTO_TCP: - { - if ( PayloadLen() < sizeof(struct tcphdr) ) - break; - - const struct tcphdr* tp = (const struct tcphdr*)data; - auto tcp_hdr = make_intrusive(tcp_hdr_type); - - int tcp_hdr_len = tp->th_off * 4; - - // account for cases in which the payload length in the TCP header is not set, - // or is set to an impossible value. In these cases, return 0. - int data_len = 0; - auto payload_len = PayloadLen(); - if ( payload_len >= tcp_hdr_len ) - data_len = payload_len - tcp_hdr_len; - - tcp_hdr->Assign(0, val_mgr->Port(ntohs(tp->th_sport), TRANSPORT_TCP)); - tcp_hdr->Assign(1, val_mgr->Port(ntohs(tp->th_dport), TRANSPORT_TCP)); - tcp_hdr->Assign(2, static_cast(ntohl(tp->th_seq))); - tcp_hdr->Assign(3, static_cast(ntohl(tp->th_ack))); - tcp_hdr->Assign(4, tcp_hdr_len); - tcp_hdr->Assign(5, data_len); - tcp_hdr->Assign(6, tp->th_x2); - tcp_hdr->Assign(7, tp->th_flags); - tcp_hdr->Assign(8, ntohs(tp->th_win)); - - pkt_hdr->Assign(sindex + 2, std::move(tcp_hdr)); - break; - } - - case IPPROTO_UDP: - { - if ( PayloadLen() < sizeof(struct udphdr) ) - break; - - const struct udphdr* up = (const struct udphdr*)data; - auto udp_hdr = make_intrusive(udp_hdr_type); - - udp_hdr->Assign(0, val_mgr->Port(ntohs(up->uh_sport), TRANSPORT_UDP)); - udp_hdr->Assign(1, val_mgr->Port(ntohs(up->uh_dport), TRANSPORT_UDP)); - udp_hdr->Assign(2, ntohs(up->uh_ulen)); - - pkt_hdr->Assign(sindex + 3, std::move(udp_hdr)); - break; - } - - case IPPROTO_ICMP: - { - if ( PayloadLen() < sizeof(struct icmp) ) - break; - - const struct icmp* icmpp = (const struct icmp*)data; - auto icmp_hdr = make_intrusive(icmp_hdr_type); - - icmp_hdr->Assign(0, icmpp->icmp_type); - - pkt_hdr->Assign(sindex + 4, std::move(icmp_hdr)); - break; - } - - case IPPROTO_ICMPV6: - { - if ( PayloadLen() < sizeof(struct icmp6_hdr) ) - break; - - const struct icmp6_hdr* icmpp = (const struct icmp6_hdr*)data; - auto icmp_hdr = make_intrusive(icmp_hdr_type); - - icmp_hdr->Assign(0, icmpp->icmp6_type); - - pkt_hdr->Assign(sindex + 4, std::move(icmp_hdr)); - break; - } - - default: - { - // This is not a protocol we understand. - break; - } - } - - return pkt_hdr; - } - -static inline bool isIPv6ExtHeader(uint8_t type) - { - switch ( type ) - { - case IPPROTO_HOPOPTS: - case IPPROTO_ROUTING: - case IPPROTO_DSTOPTS: - case IPPROTO_FRAGMENT: - case IPPROTO_AH: - case IPPROTO_ESP: - case IPPROTO_MOBILITY: - return true; - default: - return false; - } - } - -IPv6_Hdr_Chain::~IPv6_Hdr_Chain() - { - for ( size_t i = 0; i < chain.size(); ++i ) - delete chain[i]; - delete homeAddr; - delete finalDst; - } - -void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, - uint16_t next) - { - length = 0; - uint8_t current_type, next_type; - next_type = IPPROTO_IPV6; - const u_char* hdrs = (const u_char*)ip6; - - if ( total_len < (int)sizeof(struct ip6_hdr) ) - { - reporter->InternalWarning("truncated IP header in IPv6_HdrChain::Init"); - return; - } - - do - { - // We can't determine a given header's length if there's less than - // two bytes of data available (2nd byte of extension headers is length) - if ( total_len < 2 ) - return; - - current_type = next_type; - IPv6_Hdr* p = new IPv6_Hdr(current_type, hdrs); - - next_type = p->NextHdr(); - uint16_t cur_len = p->Length(); - - // If this header is truncated, don't add it to chain, don't go further. - if ( cur_len > total_len ) - { - delete p; - return; - } - - if ( set_next && next_type == IPPROTO_FRAGMENT ) - { - p->ChangeNext(next); - next_type = next; - } - - chain.push_back(p); - - // Check for routing headers and remember final destination address. - if ( current_type == IPPROTO_ROUTING ) - ProcessRoutingHeader((const struct ip6_rthdr*)hdrs, cur_len); - - // Only Mobile IPv6 has a destination option we care about right now. - if ( current_type == IPPROTO_DSTOPTS ) - ProcessDstOpts((const struct ip6_dest*)hdrs, cur_len); - - hdrs += cur_len; - length += cur_len; - total_len -= cur_len; - - } while ( current_type != IPPROTO_FRAGMENT && current_type != IPPROTO_ESP && - current_type != IPPROTO_MOBILITY && isIPv6ExtHeader(next_type) ); - } - -bool IPv6_Hdr_Chain::IsFragment() const - { - if ( chain.empty() ) - { - reporter->InternalWarning("empty IPv6 header chain"); - return false; - } - - return chain[chain.size() - 1]->Type() == IPPROTO_FRAGMENT; - } - -IPAddr IPv6_Hdr_Chain::SrcAddr() const - { - if ( homeAddr ) - return {*homeAddr}; - - if ( chain.empty() ) - { - reporter->InternalWarning("empty IPv6 header chain"); - return {}; - } - - return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_src}; - } - -IPAddr IPv6_Hdr_Chain::DstAddr() const - { - if ( finalDst ) - return {*finalDst}; - - if ( chain.empty() ) - { - reporter->InternalWarning("empty IPv6 header chain"); - return {}; - } - - return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_dst}; - } - -void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len) - { - if ( finalDst ) - { - // RFC 2460 section 4.1 says Routing should occur at most once. - reporter->Weird(SrcAddr(), DstAddr(), "multiple_routing_headers"); - return; - } - - // Last 16 bytes of header (for all known types) is the address we want. - const in6_addr* addr = (const in6_addr*)(((const u_char*)r) + len - 16); - - switch ( r->ip6r_type ) - { - case 0: // Defined by RFC 2460, deprecated by RFC 5095 - { - if ( r->ip6r_segleft > 0 && r->ip6r_len >= 2 ) - { - if ( r->ip6r_len % 2 == 0 ) - finalDst = new IPAddr(*addr); - else - reporter->Weird(SrcAddr(), DstAddr(), "odd_routing0_len"); - } - - // Always raise a weird since this type is deprecated. - reporter->Weird(SrcAddr(), DstAddr(), "routing0_hdr"); - } - break; - - case 2: // Defined by Mobile IPv6 RFC 6275. - { - if ( r->ip6r_segleft > 0 ) - { - if ( r->ip6r_len == 2 ) - finalDst = new IPAddr(*addr); - else - reporter->Weird(SrcAddr(), DstAddr(), "bad_routing2_len"); - } - } - break; - - default: - reporter->Weird(SrcAddr(), DstAddr(), "unknown_routing_type", - util::fmt("%d", r->ip6r_type)); - break; - } - } - -void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len) - { - // Skip two bytes to get the beginning of the first option structure. These - // two bytes are the protocol for the next header and extension header length, - // already known to exist before calling this method. See header format: - // https://datatracker.ietf.org/doc/html/rfc8200#section-4.6 - assert(len >= 2); - - const u_char* data = (const u_char*)d; - len -= 2 * sizeof(uint8_t); - data += 2 * sizeof(uint8_t); - - while ( len > 0 ) - { - const struct ip6_opt* opt = (const struct ip6_opt*)data; - switch ( opt->ip6o_type ) - { - case 0: - // If option type is zero, it's a Pad0 and can be just a single - // byte in width. Skip over it. - data += sizeof(uint8_t); - len -= sizeof(uint8_t); - break; - default: - { - // Double-check that the len can hold the whole option structure. - // Otherwise we get a buffer-overflow when we check the option_len. - // Also check that it holds everything for the option itself. - if ( len < sizeof(struct ip6_opt) || len < sizeof(struct ip6_opt) + opt->ip6o_len ) - { - reporter->Weird(SrcAddr(), DstAddr(), "bad_ipv6_dest_opt_len"); - len = 0; - break; - } - - if ( opt->ip6o_type == - 201 ) // Home Address Option, Mobile IPv6 RFC 6275 section 6.3 - { - if ( opt->ip6o_len == sizeof(struct in6_addr) ) - { - if ( homeAddr ) - reporter->Weird(SrcAddr(), DstAddr(), "multiple_home_addr_opts"); - else - homeAddr = new IPAddr( - *((const in6_addr*)(data + sizeof(struct ip6_opt)))); - } - else - reporter->Weird(SrcAddr(), DstAddr(), "bad_home_addr_len"); - } - - data += sizeof(struct ip6_opt) + opt->ip6o_len; - len -= sizeof(struct ip6_opt) + opt->ip6o_len; - } - break; - } - } - } - -VectorValPtr IPv6_Hdr_Chain::ToVal() const - { - static auto ip6_ext_hdr_type = id::find_type("ip6_ext_hdr"); - static auto ip6_hopopts_type = id::find_type("ip6_hopopts"); - static auto ip6_dstopts_type = id::find_type("ip6_dstopts"); - static auto ip6_routing_type = id::find_type("ip6_routing"); - static auto ip6_fragment_type = id::find_type("ip6_fragment"); - static auto ip6_ah_type = id::find_type("ip6_ah"); - static auto ip6_esp_type = id::find_type("ip6_esp"); - static auto ip6_ext_hdr_chain_type = id::find_type("ip6_ext_hdr_chain"); - auto rval = make_intrusive(ip6_ext_hdr_chain_type); - - for ( size_t i = 1; i < chain.size(); ++i ) - { - auto v = chain[i]->ToVal(); - auto ext_hdr = make_intrusive(ip6_ext_hdr_type); - uint8_t type = chain[i]->Type(); - ext_hdr->Assign(0, type); - - switch ( type ) - { - case IPPROTO_HOPOPTS: - ext_hdr->Assign(1, std::move(v)); - break; - case IPPROTO_DSTOPTS: - ext_hdr->Assign(2, std::move(v)); - break; - case IPPROTO_ROUTING: - ext_hdr->Assign(3, std::move(v)); - break; - case IPPROTO_FRAGMENT: - ext_hdr->Assign(4, std::move(v)); - break; - case IPPROTO_AH: - ext_hdr->Assign(5, std::move(v)); - break; - case IPPROTO_ESP: - ext_hdr->Assign(6, std::move(v)); - break; - case IPPROTO_MOBILITY: - ext_hdr->Assign(7, std::move(v)); - break; - default: - reporter->InternalWarning("IPv6_Hdr_Chain bad header %d", type); - continue; - } - - rval->Assign(rval->Size(), std::move(ext_hdr)); - } - - return rval; - } - -IP_Hdr* IP_Hdr::Copy() const - { - char* new_hdr = new char[HdrLen()]; - - if ( ip4 ) - { - memcpy(new_hdr, ip4, HdrLen()); - return new IP_Hdr((const struct ip*)new_hdr, true); - } - - memcpy(new_hdr, ip6, HdrLen()); - const struct ip6_hdr* new_ip6 = (const struct ip6_hdr*)new_hdr; - IPv6_Hdr_Chain* new_ip6_hdrs = ip6_hdrs->Copy(new_ip6); - return new IP_Hdr(new_ip6, true, 0, new_ip6_hdrs); - } - -IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const - { - IPv6_Hdr_Chain* rval = new IPv6_Hdr_Chain; - rval->length = length; - - if ( homeAddr ) - rval->homeAddr = new IPAddr(*homeAddr); - - if ( finalDst ) - rval->finalDst = new IPAddr(*finalDst); - - if ( chain.empty() ) - { - reporter->InternalWarning("empty IPv6 header chain"); - delete rval; - return nullptr; - } - - const u_char* new_data = (const u_char*)new_hdr; - const u_char* old_data = chain[0]->Data(); - - for ( size_t i = 0; i < chain.size(); ++i ) - { - int off = chain[i]->Data() - old_data; - rval->chain.push_back(new IPv6_Hdr(chain[i]->Type(), new_data + off)); - } - - return rval; - } - - } // namespace zeek +namespace zeek { + +bool IPv6_Hdr::IsOptionTruncated(uint16_t off) const { + if ( Length() < off ) { + reporter->Weird("truncated_IPv6_option"); + return true; + } + + return false; +} + +static VectorValPtr BuildOptionsVal(const u_char* data, int len) { + auto vv = make_intrusive(id::find_type("ip6_options")); + + while ( len > 0 && static_cast(len) >= sizeof(struct ip6_opt) ) { + static auto ip6_option_type = id::find_type("ip6_option"); + const struct ip6_opt* opt = (const struct ip6_opt*)data; + auto rv = make_intrusive(ip6_option_type); + rv->Assign(0, opt->ip6o_type); + + if ( opt->ip6o_type == 0 ) { + // Pad1 option + rv->Assign(1, 0); + rv->Assign(2, val_mgr->EmptyString()); + data += sizeof(uint8_t); + len -= sizeof(uint8_t); + } + else { + // PadN or other option + uint16_t off = 2 * sizeof(uint8_t); + + if ( len < opt->ip6o_len + off ) + break; + + rv->Assign(1, opt->ip6o_len); + rv->Assign(2, new String(data + off, opt->ip6o_len, true)); + data += opt->ip6o_len + off; + len -= opt->ip6o_len + off; + } + + vv->Assign(vv->Size(), std::move(rv)); + } + + return vv; +} + +RecordValPtr IPv6_Hdr::ToVal(VectorValPtr chain) const { + RecordValPtr rv; + + switch ( type ) { + case IPPROTO_IPV6: { + static auto ip6_hdr_type = id::find_type("ip6_hdr"); + rv = make_intrusive(ip6_hdr_type); + const struct ip6_hdr* ip6 = (const struct ip6_hdr*)data; + rv->Assign(0, static_cast(ntohl(ip6->ip6_flow) & 0x0ff00000) >> 20); + rv->Assign(1, static_cast(ntohl(ip6->ip6_flow) & 0x000fffff)); + rv->Assign(2, ntohs(ip6->ip6_plen)); + rv->Assign(3, ip6->ip6_nxt); + rv->Assign(4, ip6->ip6_hlim); + rv->Assign(5, make_intrusive(IPAddr(ip6->ip6_src))); + rv->Assign(6, make_intrusive(IPAddr(ip6->ip6_dst))); + if ( ! chain ) + chain = make_intrusive(id::find_type("ip6_ext_hdr_chain")); + rv->Assign(7, std::move(chain)); + } break; + + case IPPROTO_HOPOPTS: { + uint16_t off = 2 * sizeof(uint8_t); + if ( IsOptionTruncated(off) ) + return nullptr; + + static auto ip6_hopopts_type = id::find_type("ip6_hopopts"); + rv = make_intrusive(ip6_hopopts_type); + const struct ip6_hbh* hbh = (const struct ip6_hbh*)data; + rv->Assign(0, hbh->ip6h_nxt); + rv->Assign(1, hbh->ip6h_len); + rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); + } break; + + case IPPROTO_DSTOPTS: { + uint16_t off = 2 * sizeof(uint8_t); + if ( IsOptionTruncated(off) ) + return nullptr; + + static auto ip6_dstopts_type = id::find_type("ip6_dstopts"); + rv = make_intrusive(ip6_dstopts_type); + const struct ip6_dest* dst = (const struct ip6_dest*)data; + rv->Assign(0, dst->ip6d_nxt); + rv->Assign(1, dst->ip6d_len); + rv->Assign(2, BuildOptionsVal(data + off, Length() - off)); + } break; + + case IPPROTO_ROUTING: { + uint16_t off = 4 * sizeof(uint8_t); + if ( IsOptionTruncated(off) ) + return nullptr; + + static auto ip6_routing_type = id::find_type("ip6_routing"); + rv = make_intrusive(ip6_routing_type); + const struct ip6_rthdr* rt = (const struct ip6_rthdr*)data; + rv->Assign(0, rt->ip6r_nxt); + rv->Assign(1, rt->ip6r_len); + rv->Assign(2, rt->ip6r_type); + rv->Assign(3, rt->ip6r_segleft); + rv->Assign(4, new String(data + off, Length() - off, true)); + } break; + + case IPPROTO_FRAGMENT: { + static auto ip6_fragment_type = id::find_type("ip6_fragment"); + rv = make_intrusive(ip6_fragment_type); + const struct ip6_frag* frag = (const struct ip6_frag*)data; + rv->Assign(0, frag->ip6f_nxt); + rv->Assign(1, frag->ip6f_reserved); + rv->Assign(2, (ntohs(frag->ip6f_offlg) & 0xfff8) >> 3); + rv->Assign(3, (ntohs(frag->ip6f_offlg) & 0x0006) >> 1); + rv->Assign(4, static_cast(ntohs(frag->ip6f_offlg) & 0x0001)); + rv->Assign(5, static_cast(ntohl(frag->ip6f_ident))); + } break; + + case IPPROTO_AH: { + static auto ip6_ah_type = id::find_type("ip6_ah"); + rv = make_intrusive(ip6_ah_type); + rv->Assign(0, ((ip6_ext*)data)->ip6e_nxt); + rv->Assign(1, ((ip6_ext*)data)->ip6e_len); + rv->Assign(2, ntohs(((uint16_t*)data)[1])); + rv->Assign(3, static_cast(ntohl(((uint32_t*)data)[1]))); + + if ( Length() >= 12 ) { + // Sequence Number and ICV fields can only be extracted if + // Payload Len was non-zero for this header. + rv->Assign(4, static_cast(ntohl(((uint32_t*)data)[2]))); + uint16_t off = 3 * sizeof(uint32_t); + rv->Assign(5, new String(data + off, Length() - off, true)); + } + } break; + + case IPPROTO_ESP: { + static auto ip6_esp_type = id::find_type("ip6_esp"); + rv = make_intrusive(ip6_esp_type); + const uint32_t* esp = (const uint32_t*)data; + rv->Assign(0, static_cast(ntohl(esp[0]))); + rv->Assign(1, static_cast(ntohl(esp[1]))); + } break; + + case IPPROTO_MOBILITY: { + static auto ip6_mob_type = id::find_type("ip6_mobility_hdr"); + rv = make_intrusive(ip6_mob_type); + const struct ip6_mobility* mob = (const struct ip6_mobility*)data; + rv->Assign(0, mob->ip6mob_payload); + rv->Assign(1, mob->ip6mob_len); + rv->Assign(2, mob->ip6mob_type); + rv->Assign(3, mob->ip6mob_rsv); + rv->Assign(4, ntohs(mob->ip6mob_chksum)); + + static auto ip6_mob_msg_type = id::find_type("ip6_mobility_msg"); + auto msg = make_intrusive(ip6_mob_msg_type); + msg->Assign(0, mob->ip6mob_type); + + uint16_t off = sizeof(ip6_mobility); + const u_char* msg_data = data + off; + + static auto ip6_mob_brr_type = id::find_type("ip6_mobility_brr"); + static auto ip6_mob_hoti_type = id::find_type("ip6_mobility_hoti"); + static auto ip6_mob_coti_type = id::find_type("ip6_mobility_coti"); + static auto ip6_mob_hot_type = id::find_type("ip6_mobility_hot"); + static auto ip6_mob_cot_type = id::find_type("ip6_mobility_cot"); + static auto ip6_mob_bu_type = id::find_type("ip6_mobility_bu"); + static auto ip6_mob_back_type = id::find_type("ip6_mobility_back"); + static auto ip6_mob_be_type = id::find_type("ip6_mobility_be"); + + switch ( mob->ip6mob_type ) { + case 0: { + off += sizeof(uint16_t); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_brr_type); + m->Assign(0, ntohs(*((uint16_t*)msg_data))); + m->Assign(1, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(1, std::move(m)); + break; + } + + case 1: { + off += sizeof(uint16_t) + sizeof(uint64_t); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_hoti_type); + m->Assign(0, ntohs(*((uint16_t*)msg_data))); + m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); + m->Assign(2, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(2, std::move(m)); + break; + } + + case 2: { + off += sizeof(uint16_t) + sizeof(uint64_t); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_coti_type); + m->Assign(0, ntohs(*((uint16_t*)msg_data))); + m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); + m->Assign(2, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(3, std::move(m)); + break; + } + + case 3: { + off += sizeof(uint16_t) + 2 * sizeof(uint64_t); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_hot_type); + m->Assign(0, ntohs(*((uint16_t*)msg_data))); + m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); + m->Assign(2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t))))); + m->Assign(3, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(4, std::move(m)); + break; + } + + case 4: { + off += sizeof(uint16_t) + 2 * sizeof(uint64_t); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_cot_type); + m->Assign(0, ntohs(*((uint16_t*)msg_data))); + m->Assign(1, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t))))); + m->Assign(2, ntohll(*((uint64_t*)(msg_data + sizeof(uint16_t) + sizeof(uint64_t))))); + m->Assign(3, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(5, std::move(m)); + break; + } + + case 5: { + off += 3 * sizeof(uint16_t); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_bu_type); + m->Assign(0, ntohs(*((uint16_t*)msg_data))); + m->Assign(1, static_cast(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x8000)); + m->Assign(2, static_cast(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x4000)); + m->Assign(3, static_cast(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x2000)); + m->Assign(4, static_cast(ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t)))) & 0x1000)); + m->Assign(5, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); + m->Assign(6, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(6, std::move(m)); + break; + } + + case 6: { + off += 3 * sizeof(uint16_t); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_back_type); + m->Assign(0, *((uint8_t*)msg_data)); + m->Assign(1, static_cast(*((uint8_t*)(msg_data + sizeof(uint8_t))) & 0x80)); + m->Assign(2, ntohs(*((uint16_t*)(msg_data + sizeof(uint16_t))))); + m->Assign(3, ntohs(*((uint16_t*)(msg_data + 2 * sizeof(uint16_t))))); + m->Assign(4, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(7, std::move(m)); + break; + } + + case 7: { + off += sizeof(uint16_t) + sizeof(in6_addr); + if ( IsOptionTruncated(off) ) + break; + + auto m = make_intrusive(ip6_mob_be_type); + m->Assign(0, *((uint8_t*)msg_data)); + const in6_addr* hoa = (const in6_addr*)(msg_data + sizeof(uint16_t)); + m->Assign(1, make_intrusive(IPAddr(*hoa))); + m->Assign(2, BuildOptionsVal(data + off, Length() - off)); + msg->Assign(8, std::move(m)); + break; + } + + default: reporter->Weird("unknown_mobility_type", util::fmt("%d", mob->ip6mob_type)); break; + } + + rv->Assign(5, std::move(msg)); + } break; + + default: break; + } + + return rv; +} + +RecordValPtr IPv6_Hdr::ToVal() const { return ToVal(nullptr); } + +IPAddr IP_Hdr::IPHeaderSrcAddr() const { return ip4 ? IPAddr(ip4->ip_src) : IPAddr(ip6->ip6_src); } + +IPAddr IP_Hdr::IPHeaderDstAddr() const { return ip4 ? IPAddr(ip4->ip_dst) : IPAddr(ip6->ip6_dst); } + +IPAddr IP_Hdr::SrcAddr() const { return ip4 ? IPAddr(ip4->ip_src) : ip6_hdrs->SrcAddr(); } + +IPAddr IP_Hdr::DstAddr() const { return ip4 ? IPAddr(ip4->ip_dst) : ip6_hdrs->DstAddr(); } + +RecordValPtr IP_Hdr::ToIPHdrVal() const { + RecordValPtr rval; + + if ( ip4 ) { + static auto ip4_hdr_type = id::find_type("ip4_hdr"); + rval = make_intrusive(ip4_hdr_type); + rval->Assign(0, ip4->ip_hl * 4); + rval->Assign(1, ip4->ip_tos); + rval->Assign(2, ntohs(ip4->ip_len)); + rval->Assign(3, ntohs(ip4->ip_id)); + rval->Assign(4, DF()); + rval->Assign(5, MF()); + rval->Assign(6, FragOffset()); // 13 bit offset as multiple of 8 + rval->Assign(7, ip4->ip_ttl); + rval->Assign(8, ip4->ip_p); + rval->Assign(9, ntohs(ip4->ip_sum)); + rval->Assign(10, make_intrusive(ip4->ip_src.s_addr)); + rval->Assign(11, make_intrusive(ip4->ip_dst.s_addr)); + } + else { + rval = ((*ip6_hdrs)[0])->ToVal(ip6_hdrs->ToVal()); + } + + return rval; +} + +RecordValPtr IP_Hdr::ToPktHdrVal() const { + static auto pkt_hdr_type = id::find_type("pkt_hdr"); + return ToPktHdrVal(make_intrusive(pkt_hdr_type), 0); +} + +RecordValPtr IP_Hdr::ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const { + static auto tcp_hdr_type = id::find_type("tcp_hdr"); + static auto udp_hdr_type = id::find_type("udp_hdr"); + static auto icmp_hdr_type = id::find_type("icmp_hdr"); + + if ( ip4 ) + pkt_hdr->Assign(sindex + 0, ToIPHdrVal()); + else + pkt_hdr->Assign(sindex + 1, ToIPHdrVal()); + + // L4 header. + const u_char* data = Payload(); + + int proto = NextProto(); + switch ( proto ) { + case IPPROTO_TCP: { + if ( PayloadLen() < sizeof(struct tcphdr) ) + break; + + const struct tcphdr* tp = (const struct tcphdr*)data; + auto tcp_hdr = make_intrusive(tcp_hdr_type); + + int tcp_hdr_len = tp->th_off * 4; + + // account for cases in which the payload length in the TCP header is not set, + // or is set to an impossible value. In these cases, return 0. + int data_len = 0; + auto payload_len = PayloadLen(); + if ( payload_len >= tcp_hdr_len ) + data_len = payload_len - tcp_hdr_len; + + tcp_hdr->Assign(0, val_mgr->Port(ntohs(tp->th_sport), TRANSPORT_TCP)); + tcp_hdr->Assign(1, val_mgr->Port(ntohs(tp->th_dport), TRANSPORT_TCP)); + tcp_hdr->Assign(2, static_cast(ntohl(tp->th_seq))); + tcp_hdr->Assign(3, static_cast(ntohl(tp->th_ack))); + tcp_hdr->Assign(4, tcp_hdr_len); + tcp_hdr->Assign(5, data_len); + tcp_hdr->Assign(6, tp->th_x2); + tcp_hdr->Assign(7, tp->th_flags); + tcp_hdr->Assign(8, ntohs(tp->th_win)); + + pkt_hdr->Assign(sindex + 2, std::move(tcp_hdr)); + break; + } + + case IPPROTO_UDP: { + if ( PayloadLen() < sizeof(struct udphdr) ) + break; + + const struct udphdr* up = (const struct udphdr*)data; + auto udp_hdr = make_intrusive(udp_hdr_type); + + udp_hdr->Assign(0, val_mgr->Port(ntohs(up->uh_sport), TRANSPORT_UDP)); + udp_hdr->Assign(1, val_mgr->Port(ntohs(up->uh_dport), TRANSPORT_UDP)); + udp_hdr->Assign(2, ntohs(up->uh_ulen)); + + pkt_hdr->Assign(sindex + 3, std::move(udp_hdr)); + break; + } + + case IPPROTO_ICMP: { + if ( PayloadLen() < sizeof(struct icmp) ) + break; + + const struct icmp* icmpp = (const struct icmp*)data; + auto icmp_hdr = make_intrusive(icmp_hdr_type); + + icmp_hdr->Assign(0, icmpp->icmp_type); + + pkt_hdr->Assign(sindex + 4, std::move(icmp_hdr)); + break; + } + + case IPPROTO_ICMPV6: { + if ( PayloadLen() < sizeof(struct icmp6_hdr) ) + break; + + const struct icmp6_hdr* icmpp = (const struct icmp6_hdr*)data; + auto icmp_hdr = make_intrusive(icmp_hdr_type); + + icmp_hdr->Assign(0, icmpp->icmp6_type); + + pkt_hdr->Assign(sindex + 4, std::move(icmp_hdr)); + break; + } + + default: { + // This is not a protocol we understand. + break; + } + } + + return pkt_hdr; +} + +static inline bool isIPv6ExtHeader(uint8_t type) { + switch ( type ) { + case IPPROTO_HOPOPTS: + case IPPROTO_ROUTING: + case IPPROTO_DSTOPTS: + case IPPROTO_FRAGMENT: + case IPPROTO_AH: + case IPPROTO_ESP: + case IPPROTO_MOBILITY: return true; + default: return false; + } +} + +IPv6_Hdr_Chain::~IPv6_Hdr_Chain() { + for ( size_t i = 0; i < chain.size(); ++i ) + delete chain[i]; + delete homeAddr; + delete finalDst; +} + +void IPv6_Hdr_Chain::Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, uint16_t next) { + length = 0; + uint8_t current_type, next_type; + next_type = IPPROTO_IPV6; + const u_char* hdrs = (const u_char*)ip6; + + if ( total_len < (int)sizeof(struct ip6_hdr) ) { + reporter->InternalWarning("truncated IP header in IPv6_HdrChain::Init"); + return; + } + + do { + // We can't determine a given header's length if there's less than + // two bytes of data available (2nd byte of extension headers is length) + if ( total_len < 2 ) + return; + + current_type = next_type; + IPv6_Hdr* p = new IPv6_Hdr(current_type, hdrs); + + next_type = p->NextHdr(); + uint16_t cur_len = p->Length(); + + // If this header is truncated, don't add it to chain, don't go further. + if ( cur_len > total_len ) { + delete p; + return; + } + + if ( set_next && next_type == IPPROTO_FRAGMENT ) { + p->ChangeNext(next); + next_type = next; + } + + chain.push_back(p); + + // Check for routing headers and remember final destination address. + if ( current_type == IPPROTO_ROUTING ) + ProcessRoutingHeader((const struct ip6_rthdr*)hdrs, cur_len); + + // Only Mobile IPv6 has a destination option we care about right now. + if ( current_type == IPPROTO_DSTOPTS ) + ProcessDstOpts((const struct ip6_dest*)hdrs, cur_len); + + hdrs += cur_len; + length += cur_len; + total_len -= cur_len; + + } while ( current_type != IPPROTO_FRAGMENT && current_type != IPPROTO_ESP && current_type != IPPROTO_MOBILITY && + isIPv6ExtHeader(next_type) ); +} + +bool IPv6_Hdr_Chain::IsFragment() const { + if ( chain.empty() ) { + reporter->InternalWarning("empty IPv6 header chain"); + return false; + } + + return chain[chain.size() - 1]->Type() == IPPROTO_FRAGMENT; +} + +IPAddr IPv6_Hdr_Chain::SrcAddr() const { + if ( homeAddr ) + return {*homeAddr}; + + if ( chain.empty() ) { + reporter->InternalWarning("empty IPv6 header chain"); + return {}; + } + + return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_src}; +} + +IPAddr IPv6_Hdr_Chain::DstAddr() const { + if ( finalDst ) + return {*finalDst}; + + if ( chain.empty() ) { + reporter->InternalWarning("empty IPv6 header chain"); + return {}; + } + + return IPAddr{((const struct ip6_hdr*)(chain[0]->Data()))->ip6_dst}; +} + +void IPv6_Hdr_Chain::ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len) { + if ( finalDst ) { + // RFC 2460 section 4.1 says Routing should occur at most once. + reporter->Weird(SrcAddr(), DstAddr(), "multiple_routing_headers"); + return; + } + + // Last 16 bytes of header (for all known types) is the address we want. + const in6_addr* addr = (const in6_addr*)(((const u_char*)r) + len - 16); + + switch ( r->ip6r_type ) { + case 0: // Defined by RFC 2460, deprecated by RFC 5095 + { + if ( r->ip6r_segleft > 0 && r->ip6r_len >= 2 ) { + if ( r->ip6r_len % 2 == 0 ) + finalDst = new IPAddr(*addr); + else + reporter->Weird(SrcAddr(), DstAddr(), "odd_routing0_len"); + } + + // Always raise a weird since this type is deprecated. + reporter->Weird(SrcAddr(), DstAddr(), "routing0_hdr"); + } break; + + case 2: // Defined by Mobile IPv6 RFC 6275. + { + if ( r->ip6r_segleft > 0 ) { + if ( r->ip6r_len == 2 ) + finalDst = new IPAddr(*addr); + else + reporter->Weird(SrcAddr(), DstAddr(), "bad_routing2_len"); + } + } break; + + default: reporter->Weird(SrcAddr(), DstAddr(), "unknown_routing_type", util::fmt("%d", r->ip6r_type)); break; + } +} + +void IPv6_Hdr_Chain::ProcessDstOpts(const struct ip6_dest* d, uint16_t len) { + // Skip two bytes to get the beginning of the first option structure. These + // two bytes are the protocol for the next header and extension header length, + // already known to exist before calling this method. See header format: + // https://datatracker.ietf.org/doc/html/rfc8200#section-4.6 + assert(len >= 2); + + const u_char* data = (const u_char*)d; + len -= 2 * sizeof(uint8_t); + data += 2 * sizeof(uint8_t); + + while ( len > 0 ) { + const struct ip6_opt* opt = (const struct ip6_opt*)data; + switch ( opt->ip6o_type ) { + case 0: + // If option type is zero, it's a Pad0 and can be just a single + // byte in width. Skip over it. + data += sizeof(uint8_t); + len -= sizeof(uint8_t); + break; + default: { + // Double-check that the len can hold the whole option structure. + // Otherwise we get a buffer-overflow when we check the option_len. + // Also check that it holds everything for the option itself. + if ( len < sizeof(struct ip6_opt) || len < sizeof(struct ip6_opt) + opt->ip6o_len ) { + reporter->Weird(SrcAddr(), DstAddr(), "bad_ipv6_dest_opt_len"); + len = 0; + break; + } + + if ( opt->ip6o_type == 201 ) // Home Address Option, Mobile IPv6 RFC 6275 section 6.3 + { + if ( opt->ip6o_len == sizeof(struct in6_addr) ) { + if ( homeAddr ) + reporter->Weird(SrcAddr(), DstAddr(), "multiple_home_addr_opts"); + else + homeAddr = new IPAddr(*((const in6_addr*)(data + sizeof(struct ip6_opt)))); + } + else + reporter->Weird(SrcAddr(), DstAddr(), "bad_home_addr_len"); + } + + data += sizeof(struct ip6_opt) + opt->ip6o_len; + len -= sizeof(struct ip6_opt) + opt->ip6o_len; + } break; + } + } +} + +VectorValPtr IPv6_Hdr_Chain::ToVal() const { + static auto ip6_ext_hdr_type = id::find_type("ip6_ext_hdr"); + static auto ip6_hopopts_type = id::find_type("ip6_hopopts"); + static auto ip6_dstopts_type = id::find_type("ip6_dstopts"); + static auto ip6_routing_type = id::find_type("ip6_routing"); + static auto ip6_fragment_type = id::find_type("ip6_fragment"); + static auto ip6_ah_type = id::find_type("ip6_ah"); + static auto ip6_esp_type = id::find_type("ip6_esp"); + static auto ip6_ext_hdr_chain_type = id::find_type("ip6_ext_hdr_chain"); + auto rval = make_intrusive(ip6_ext_hdr_chain_type); + + for ( size_t i = 1; i < chain.size(); ++i ) { + auto v = chain[i]->ToVal(); + auto ext_hdr = make_intrusive(ip6_ext_hdr_type); + uint8_t type = chain[i]->Type(); + ext_hdr->Assign(0, type); + + switch ( type ) { + case IPPROTO_HOPOPTS: ext_hdr->Assign(1, std::move(v)); break; + case IPPROTO_DSTOPTS: ext_hdr->Assign(2, std::move(v)); break; + case IPPROTO_ROUTING: ext_hdr->Assign(3, std::move(v)); break; + case IPPROTO_FRAGMENT: ext_hdr->Assign(4, std::move(v)); break; + case IPPROTO_AH: ext_hdr->Assign(5, std::move(v)); break; + case IPPROTO_ESP: ext_hdr->Assign(6, std::move(v)); break; + case IPPROTO_MOBILITY: ext_hdr->Assign(7, std::move(v)); break; + default: reporter->InternalWarning("IPv6_Hdr_Chain bad header %d", type); continue; + } + + rval->Assign(rval->Size(), std::move(ext_hdr)); + } + + return rval; +} + +IP_Hdr* IP_Hdr::Copy() const { + char* new_hdr = new char[HdrLen()]; + + if ( ip4 ) { + memcpy(new_hdr, ip4, HdrLen()); + return new IP_Hdr((const struct ip*)new_hdr, true); + } + + memcpy(new_hdr, ip6, HdrLen()); + const struct ip6_hdr* new_ip6 = (const struct ip6_hdr*)new_hdr; + IPv6_Hdr_Chain* new_ip6_hdrs = ip6_hdrs->Copy(new_ip6); + return new IP_Hdr(new_ip6, true, 0, new_ip6_hdrs); +} + +IPv6_Hdr_Chain* IPv6_Hdr_Chain::Copy(const ip6_hdr* new_hdr) const { + IPv6_Hdr_Chain* rval = new IPv6_Hdr_Chain; + rval->length = length; + + if ( homeAddr ) + rval->homeAddr = new IPAddr(*homeAddr); + + if ( finalDst ) + rval->finalDst = new IPAddr(*finalDst); + + if ( chain.empty() ) { + reporter->InternalWarning("empty IPv6 header chain"); + delete rval; + return nullptr; + } + + const u_char* new_data = (const u_char*)new_hdr; + const u_char* old_data = chain[0]->Data(); + + for ( size_t i = 0; i < chain.size(); ++i ) { + int off = chain[i]->Data() - old_data; + rval->chain.push_back(new IPv6_Hdr(chain[i]->Type(), new_data + off)); + } + + return rval; +} + +} // namespace zeek diff --git a/src/IP.h b/src/IP.h index 40d1ba574d..8744c8b955 100644 --- a/src/IP.h +++ b/src/IP.h @@ -20,8 +20,7 @@ #include "zeek/IntrusivePtr.h" -namespace zeek - { +namespace zeek { class IPAddr; class RecordVal; @@ -29,530 +28,478 @@ class VectorVal; using RecordValPtr = IntrusivePtr; using VectorValPtr = IntrusivePtr; -namespace detail - { +namespace detail { class FragReassembler; - } +} #ifndef IPPROTO_MOBILITY #define IPPROTO_MOBILITY 135 #endif -struct ip6_mobility - { - uint8_t ip6mob_payload; - uint8_t ip6mob_len; - uint8_t ip6mob_type; - uint8_t ip6mob_rsv; - uint16_t ip6mob_chksum; - }; +struct ip6_mobility { + uint8_t ip6mob_payload; + uint8_t ip6mob_len; + uint8_t ip6mob_type; + uint8_t ip6mob_rsv; + uint16_t ip6mob_chksum; +}; /** * Base class for IPv6 header/extensions. */ -class IPv6_Hdr - { +class IPv6_Hdr { public: - /** - * Construct an IPv6 header or extension header from assigned type number. - */ - IPv6_Hdr(uint8_t t, const u_char* d) : type(t), data(d) { } + /** + * Construct an IPv6 header or extension header from assigned type number. + */ + IPv6_Hdr(uint8_t t, const u_char* d) : type(t), data(d) {} - /** - * Replace the value of the next protocol field. - */ - void ChangeNext(uint8_t next_type) - { - switch ( type ) - { - case IPPROTO_IPV6: - ((ip6_hdr*)data)->ip6_nxt = next_type; - break; - case IPPROTO_HOPOPTS: - case IPPROTO_DSTOPTS: - case IPPROTO_ROUTING: - case IPPROTO_FRAGMENT: - case IPPROTO_AH: - case IPPROTO_MOBILITY: - ((ip6_ext*)data)->ip6e_nxt = next_type; - break; - case IPPROTO_ESP: - default: - break; - } - } + /** + * Replace the value of the next protocol field. + */ + void ChangeNext(uint8_t next_type) { + switch ( type ) { + case IPPROTO_IPV6: ((ip6_hdr*)data)->ip6_nxt = next_type; break; + case IPPROTO_HOPOPTS: + case IPPROTO_DSTOPTS: + case IPPROTO_ROUTING: + case IPPROTO_FRAGMENT: + case IPPROTO_AH: + case IPPROTO_MOBILITY: ((ip6_ext*)data)->ip6e_nxt = next_type; break; + case IPPROTO_ESP: + default: break; + } + } - ~IPv6_Hdr() { } + ~IPv6_Hdr() {} - /** - * Returns the assigned IPv6 extension header type number of the header - * that immediately follows this one. - */ - uint8_t NextHdr() const - { - switch ( type ) - { - case IPPROTO_IPV6: - return ((ip6_hdr*)data)->ip6_nxt; - case IPPROTO_HOPOPTS: - case IPPROTO_DSTOPTS: - case IPPROTO_ROUTING: - case IPPROTO_FRAGMENT: - case IPPROTO_AH: - case IPPROTO_MOBILITY: - return ((ip6_ext*)data)->ip6e_nxt; - case IPPROTO_ESP: - default: - return IPPROTO_NONE; - } - } + /** + * Returns the assigned IPv6 extension header type number of the header + * that immediately follows this one. + */ + uint8_t NextHdr() const { + switch ( type ) { + case IPPROTO_IPV6: return ((ip6_hdr*)data)->ip6_nxt; + case IPPROTO_HOPOPTS: + case IPPROTO_DSTOPTS: + case IPPROTO_ROUTING: + case IPPROTO_FRAGMENT: + case IPPROTO_AH: + case IPPROTO_MOBILITY: return ((ip6_ext*)data)->ip6e_nxt; + case IPPROTO_ESP: + default: return IPPROTO_NONE; + } + } - /** - * Returns the length of the header in bytes. - */ - uint16_t Length() const - { - switch ( type ) - { - case IPPROTO_IPV6: - return 40; - case IPPROTO_HOPOPTS: - case IPPROTO_DSTOPTS: - case IPPROTO_ROUTING: - case IPPROTO_MOBILITY: - return 8 + 8 * ((ip6_ext*)data)->ip6e_len; - case IPPROTO_FRAGMENT: - return 8; - case IPPROTO_AH: - return 8 + 4 * ((ip6_ext*)data)->ip6e_len; - case IPPROTO_ESP: - return 8; // encrypted payload begins after 8 bytes - default: - return 0; - } - } + /** + * Returns the length of the header in bytes. + */ + uint16_t Length() const { + switch ( type ) { + case IPPROTO_IPV6: return 40; + case IPPROTO_HOPOPTS: + case IPPROTO_DSTOPTS: + case IPPROTO_ROUTING: + case IPPROTO_MOBILITY: return 8 + 8 * ((ip6_ext*)data)->ip6e_len; + case IPPROTO_FRAGMENT: return 8; + case IPPROTO_AH: return 8 + 4 * ((ip6_ext*)data)->ip6e_len; + case IPPROTO_ESP: return 8; // encrypted payload begins after 8 bytes + default: return 0; + } + } - /** - * Returns the RFC 1700 et seq. IANA assigned number for the header. - */ - uint8_t Type() const { return type; } + /** + * Returns the RFC 1700 et seq. IANA assigned number for the header. + */ + uint8_t Type() const { return type; } - /** - * Returns pointer to the start of where header structure resides in memory. - */ - const u_char* Data() const { return data; } + /** + * Returns pointer to the start of where header structure resides in memory. + */ + const u_char* Data() const { return data; } - /** - * Returns the script-layer record representation of the header. - */ - RecordValPtr ToVal(VectorValPtr chain) const; - RecordValPtr ToVal() const; + /** + * Returns the script-layer record representation of the header. + */ + RecordValPtr ToVal(VectorValPtr chain) const; + RecordValPtr ToVal() const; protected: - uint8_t type; - const u_char* data; + uint8_t type; + const u_char* data; private: - bool IsOptionTruncated(uint16_t off) const; - }; + bool IsOptionTruncated(uint16_t off) const; +}; -class IPv6_Hdr_Chain - { +class IPv6_Hdr_Chain { public: - /** - * Initializes the header chain from an IPv6 header structure. - */ - IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint64_t len) { Init(ip6, len, false); } + /** + * Initializes the header chain from an IPv6 header structure. + */ + IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint64_t len) { Init(ip6, len, false); } - ~IPv6_Hdr_Chain(); + ~IPv6_Hdr_Chain(); - /** - * @return a copy of the header chain, but with pointers to individual - * IPv6 headers now pointing within \a new_hdr. - */ - IPv6_Hdr_Chain* Copy(const struct ip6_hdr* new_hdr) const; + /** + * @return a copy of the header chain, but with pointers to individual + * IPv6 headers now pointing within \a new_hdr. + */ + IPv6_Hdr_Chain* Copy(const struct ip6_hdr* new_hdr) const; - /** - * Returns the number of headers in the chain. - */ - size_t Size() const { return chain.size(); } + /** + * Returns the number of headers in the chain. + */ + size_t Size() const { return chain.size(); } - /** - * Returns the sum of the length of all headers in the chain in bytes. - */ - uint16_t TotalLength() const { return length; } + /** + * Returns the sum of the length of all headers in the chain in bytes. + */ + uint16_t TotalLength() const { return length; } - /** - * Accesses the header at the given location in the chain. - */ - const IPv6_Hdr* operator[](const size_t i) const { return chain[i]; } + /** + * Accesses the header at the given location in the chain. + */ + const IPv6_Hdr* operator[](const size_t i) const { return chain[i]; } - /** - * Returns whether the header chain indicates a fragmented packet. - */ - bool IsFragment() const; + /** + * Returns whether the header chain indicates a fragmented packet. + */ + bool IsFragment() const; - /** - * Returns pointer to fragment header structure if the chain contains one. - */ - const struct ip6_frag* GetFragHdr() const - { - return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr; - } + /** + * Returns pointer to fragment header structure if the chain contains one. + */ + const struct ip6_frag* GetFragHdr() const { + return IsFragment() ? (const struct ip6_frag*)chain[chain.size() - 1]->Data() : nullptr; + } - /** - * If the header chain is a fragment, returns the offset in number of bytes - * relative to the start of the Fragmentable Part of the original packet. - */ - uint16_t FragOffset() const - { - return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0; - } + /** + * If the header chain is a fragment, returns the offset in number of bytes + * relative to the start of the Fragmentable Part of the original packet. + */ + uint16_t FragOffset() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0xfff8) : 0; } - /** - * If the header chain is a fragment, returns the identification field. - */ - uint32_t ID() const { return IsFragment() ? ntohl(GetFragHdr()->ip6f_ident) : 0; } + /** + * If the header chain is a fragment, returns the identification field. + */ + uint32_t ID() const { return IsFragment() ? ntohl(GetFragHdr()->ip6f_ident) : 0; } - /** - * If the header chain is a fragment, returns the M (more fragments) flag. - */ - int MF() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0x0001) != 0 : 0; } + /** + * If the header chain is a fragment, returns the M (more fragments) flag. + */ + int MF() const { return IsFragment() ? (ntohs(GetFragHdr()->ip6f_offlg) & 0x0001) != 0 : 0; } - /** - * If the chain contains a Destination Options header with a Home Address - * option as defined by Mobile IPv6 (RFC 6275), then return it, else - * return the source address in the main IPv6 header. - */ - IPAddr SrcAddr() const; + /** + * If the chain contains a Destination Options header with a Home Address + * option as defined by Mobile IPv6 (RFC 6275), then return it, else + * return the source address in the main IPv6 header. + */ + IPAddr SrcAddr() const; - /** - * If the chain contains a Routing header with non-zero segments left, - * then return the last address of the first such header, else return - * the destination address of the main IPv6 header. - */ - IPAddr DstAddr() const; + /** + * If the chain contains a Routing header with non-zero segments left, + * then return the last address of the first such header, else return + * the destination address of the main IPv6 header. + */ + IPAddr DstAddr() const; - /** - * Returns a vector of ip6_ext_hdr RecordVals that includes script-layer - * representation of all extension headers in the chain. - */ - VectorValPtr ToVal() const; + /** + * Returns a vector of ip6_ext_hdr RecordVals that includes script-layer + * representation of all extension headers in the chain. + */ + VectorValPtr ToVal() const; protected: - // for access to protected ctor that changes next header values that - // point to a fragment - friend class detail::FragReassembler; + // for access to protected ctor that changes next header values that + // point to a fragment + friend class detail::FragReassembler; - IPv6_Hdr_Chain() = default; + IPv6_Hdr_Chain() = default; - /** - * Initializes the header chain from an IPv6 header structure, and replaces - * the first next protocol pointer field that points to a fragment header. - */ - IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) - { - Init(ip6, len, true, next); - } + /** + * Initializes the header chain from an IPv6 header structure, and replaces + * the first next protocol pointer field that points to a fragment header. + */ + IPv6_Hdr_Chain(const struct ip6_hdr* ip6, uint16_t next, uint64_t len) { Init(ip6, len, true, next); } - /** - * Initializes the header chain from an IPv6 header structure of a given - * length, possibly setting the first next protocol pointer field that - * points to a fragment header. - */ - void Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, uint16_t next = 0); + /** + * Initializes the header chain from an IPv6 header structure of a given + * length, possibly setting the first next protocol pointer field that + * points to a fragment header. + */ + void Init(const struct ip6_hdr* ip6, uint64_t total_len, bool set_next, uint16_t next = 0); - /** - * Process a routing header and allocate/remember the final destination - * address if it has segments left and is a valid routing header. - */ - void ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len); + /** + * Process a routing header and allocate/remember the final destination + * address if it has segments left and is a valid routing header. + */ + void ProcessRoutingHeader(const struct ip6_rthdr* r, uint16_t len); - /** - * Inspect a Destination Option header's options for things we need to - * remember, such as the Home Address option from Mobile IPv6. - */ - void ProcessDstOpts(const struct ip6_dest* d, uint16_t len); + /** + * Inspect a Destination Option header's options for things we need to + * remember, such as the Home Address option from Mobile IPv6. + */ + void ProcessDstOpts(const struct ip6_dest* d, uint16_t len); - std::vector chain; + std::vector chain; - /** - * The summation of all header lengths in the chain in bytes. - */ - uint16_t length = 0; + /** + * The summation of all header lengths in the chain in bytes. + */ + uint16_t length = 0; - /** - * Home Address of the packet's source as defined by Mobile IPv6 (RFC 6275). - */ - IPAddr* homeAddr = nullptr; + /** + * Home Address of the packet's source as defined by Mobile IPv6 (RFC 6275). + */ + IPAddr* homeAddr = nullptr; - /** - * The final destination address in chain's first Routing header that has - * non-zero segments left. - */ - IPAddr* finalDst = nullptr; - }; + /** + * The final destination address in chain's first Routing header that has + * non-zero segments left. + */ + IPAddr* finalDst = nullptr; +}; /** * A class that wraps either an IPv4 or IPv6 packet and abstracts methods * for inquiring about common features between the two. */ -class IP_Hdr - { +class IP_Hdr { public: - /** - * Construct the header wrapper from an IPv4 packet. Caller must have - * already checked that the header is not truncated. - * @param arg_ip4 pointer to memory containing an IPv4 packet. - * @param arg_del whether to take ownership of \a arg_ip4 pointer's memory. - * @param reassembled whether this header is for a reassembled packet. - */ - IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false) - : ip4(arg_ip4), del(arg_del), reassembled(reassembled) - { - } + /** + * Construct the header wrapper from an IPv4 packet. Caller must have + * already checked that the header is not truncated. + * @param arg_ip4 pointer to memory containing an IPv4 packet. + * @param arg_del whether to take ownership of \a arg_ip4 pointer's memory. + * @param reassembled whether this header is for a reassembled packet. + */ + IP_Hdr(const struct ip* arg_ip4, bool arg_del, bool reassembled = false) + : ip4(arg_ip4), del(arg_del), reassembled(reassembled) {} - /** - * Construct the header wrapper from an IPv6 packet. Caller must have - * already checked that the static IPv6 header is not truncated. If - * the packet contains extension headers and they are truncated, that can - * be checked afterwards by comparing \a len with \a TotalLen. E.g. - * The IP packet analyzer does this to skip truncated packets. - * @param arg_ip6 pointer to memory containing an IPv6 packet. - * @param arg_del whether to take ownership of \a arg_ip6 pointer's memory. - * @param len the packet's length in bytes. - * @param c an already-constructed header chain to take ownership of. - * @param reassembled whether this header is for a reassembled packet. - */ - IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, - const IPv6_Hdr_Chain* c = nullptr, bool reassembled = false) - : ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), - reassembled(reassembled) - { - } + /** + * Construct the header wrapper from an IPv6 packet. Caller must have + * already checked that the static IPv6 header is not truncated. If + * the packet contains extension headers and they are truncated, that can + * be checked afterwards by comparing \a len with \a TotalLen. E.g. + * The IP packet analyzer does this to skip truncated packets. + * @param arg_ip6 pointer to memory containing an IPv6 packet. + * @param arg_del whether to take ownership of \a arg_ip6 pointer's memory. + * @param len the packet's length in bytes. + * @param c an already-constructed header chain to take ownership of. + * @param reassembled whether this header is for a reassembled packet. + */ + IP_Hdr(const struct ip6_hdr* arg_ip6, bool arg_del, uint64_t len, const IPv6_Hdr_Chain* c = nullptr, + bool reassembled = false) + : ip6(arg_ip6), ip6_hdrs(c ? c : new IPv6_Hdr_Chain(ip6, len)), del(arg_del), reassembled(reassembled) {} - /** - * Copy a header. The internal buffer which contains the header data - * must not be truncated. Also note that if that buffer points to a full - * packet payload, only the IP header portion is copied. - */ - IP_Hdr* Copy() const; + /** + * Copy a header. The internal buffer which contains the header data + * must not be truncated. Also note that if that buffer points to a full + * packet payload, only the IP header portion is copied. + */ + IP_Hdr* Copy() const; - /** - * Destructor. - */ - ~IP_Hdr() - { - delete ip6_hdrs; + /** + * Destructor. + */ + ~IP_Hdr() { + delete ip6_hdrs; - if ( del ) - { - delete[](struct ip*) ip4; - delete[](struct ip6_hdr*) ip6; - } - } + if ( del ) { + delete[](struct ip*) ip4; + delete[](struct ip6_hdr*) ip6; + } + } - /** - * If an IPv4 packet is wrapped, return a pointer to it, else null. - */ - const struct ip* IP4_Hdr() const { return ip4; } + /** + * If an IPv4 packet is wrapped, return a pointer to it, else null. + */ + const struct ip* IP4_Hdr() const { return ip4; } - /** - * If an IPv6 packet is wrapped, return a pointer to it, else null. - */ - const struct ip6_hdr* IP6_Hdr() const { return ip6; } + /** + * If an IPv6 packet is wrapped, return a pointer to it, else null. + */ + const struct ip6_hdr* IP6_Hdr() const { return ip6; } - /** - * Returns the source address held in the IP header. - */ - IPAddr IPHeaderSrcAddr() const; + /** + * Returns the source address held in the IP header. + */ + IPAddr IPHeaderSrcAddr() const; - /** - * Returns the destination address held in the IP header. - */ - IPAddr IPHeaderDstAddr() const; + /** + * Returns the destination address held in the IP header. + */ + IPAddr IPHeaderDstAddr() const; - /** - * For IPv4 or IPv6 headers that don't contain a Home Address option - * (Mobile IPv6, RFC 6275), return source address held in the IP header. - * For IPv6 headers that contain a Home Address option, return that address. - */ - IPAddr SrcAddr() const; + /** + * For IPv4 or IPv6 headers that don't contain a Home Address option + * (Mobile IPv6, RFC 6275), return source address held in the IP header. + * For IPv6 headers that contain a Home Address option, return that address. + */ + IPAddr SrcAddr() const; - /** - * For IPv4 or IPv6 headers that don't contain a Routing header with - * non-zero segments left, return destination address held in the IP header. - * For IPv6 headers with a Routing header that has non-zero segments left, - * return the last address in the first such Routing header. - */ - IPAddr DstAddr() const; + /** + * For IPv4 or IPv6 headers that don't contain a Routing header with + * non-zero segments left, return destination address held in the IP header. + * For IPv6 headers with a Routing header that has non-zero segments left, + * return the last address in the first such Routing header. + */ + IPAddr DstAddr() const; - /** - * Returns a pointer to the payload of the IP packet, usually an - * upper-layer protocol. - */ - const u_char* Payload() const - { - if ( ip4 ) - return ((const u_char*)ip4) + ip4->ip_hl * 4; + /** + * Returns a pointer to the payload of the IP packet, usually an + * upper-layer protocol. + */ + const u_char* Payload() const { + if ( ip4 ) + return ((const u_char*)ip4) + ip4->ip_hl * 4; - return ((const u_char*)ip6) + ip6_hdrs->TotalLength(); - } + return ((const u_char*)ip6) + ip6_hdrs->TotalLength(); + } - /** - * Returns a pointer to the mobility header of the IP packet, if present, - * else a null pointer. - */ - const ip6_mobility* MobilityHeader() const - { - if ( ip4 ) - return nullptr; - else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY ) - return nullptr; - else - return (const ip6_mobility*)(*ip6_hdrs)[ip6_hdrs->Size() - 1]->Data(); - } + /** + * Returns a pointer to the mobility header of the IP packet, if present, + * else a null pointer. + */ + const ip6_mobility* MobilityHeader() const { + if ( ip4 ) + return nullptr; + else if ( (*ip6_hdrs)[ip6_hdrs->Size() - 1]->Type() != IPPROTO_MOBILITY ) + return nullptr; + else + return (const ip6_mobility*)(*ip6_hdrs)[ip6_hdrs->Size() - 1]->Data(); + } - /** - * Returns the length of the IP packet's payload (length of packet minus - * header length or, for IPv6, also minus length of all extension headers). - * - * Also returns 0 if the IPv4 length field is set to zero - which is, e.g., - * the case when TCP segment offloading is enabled. - */ - uint16_t PayloadLen() const - { - if ( ip4 ) - { - // prevent overflow in case of segment offloading/zeroed header length. - auto total_len = ntohs(ip4->ip_len); - return total_len ? total_len - ip4->ip_hl * 4 : 0; - } + /** + * Returns the length of the IP packet's payload (length of packet minus + * header length or, for IPv6, also minus length of all extension headers). + * + * Also returns 0 if the IPv4 length field is set to zero - which is, e.g., + * the case when TCP segment offloading is enabled. + */ + uint16_t PayloadLen() const { + if ( ip4 ) { + // prevent overflow in case of segment offloading/zeroed header length. + auto total_len = ntohs(ip4->ip_len); + return total_len ? total_len - ip4->ip_hl * 4 : 0; + } - return ntohs(ip6->ip6_plen) + 40 - ip6_hdrs->TotalLength(); - } + return ntohs(ip6->ip6_plen) + 40 - ip6_hdrs->TotalLength(); + } - /** - * Returns the length of the IP packet (length of headers and payload). - */ - uint32_t TotalLen() const - { - if ( ip4 ) - return ntohs(ip4->ip_len); + /** + * Returns the length of the IP packet (length of headers and payload). + */ + uint32_t TotalLen() const { + if ( ip4 ) + return ntohs(ip4->ip_len); - return ntohs(ip6->ip6_plen) + 40; - } + return ntohs(ip6->ip6_plen) + 40; + } - /** - * Returns length of IP packet header (includes extension headers for IPv6). - */ - uint16_t HdrLen() const { return ip4 ? ip4->ip_hl * 4 : ip6_hdrs->TotalLength(); } + /** + * Returns length of IP packet header (includes extension headers for IPv6). + */ + uint16_t HdrLen() const { return ip4 ? ip4->ip_hl * 4 : ip6_hdrs->TotalLength(); } - /** - * For IPv6 header chains, returns the type of the last header in the chain. - */ - uint8_t LastHeader() const - { - if ( ip4 ) - return IPPROTO_RAW; + /** + * For IPv6 header chains, returns the type of the last header in the chain. + */ + uint8_t LastHeader() const { + if ( ip4 ) + return IPPROTO_RAW; - size_t i = ip6_hdrs->Size(); - if ( i > 0 ) - return (*ip6_hdrs)[i - 1]->Type(); + size_t i = ip6_hdrs->Size(); + if ( i > 0 ) + return (*ip6_hdrs)[i - 1]->Type(); - return IPPROTO_NONE; - } + return IPPROTO_NONE; + } - /** - * Returns the protocol type of the IP packet's payload, usually an - * upper-layer protocol. For IPv6, this returns the last (extension) - * header's Next Header value. - */ - unsigned char NextProto() const - { - if ( ip4 ) - return ip4->ip_p; + /** + * Returns the protocol type of the IP packet's payload, usually an + * upper-layer protocol. For IPv6, this returns the last (extension) + * header's Next Header value. + */ + unsigned char NextProto() const { + if ( ip4 ) + return ip4->ip_p; - size_t i = ip6_hdrs->Size(); - if ( i > 0 ) - return (*ip6_hdrs)[i - 1]->NextHdr(); + size_t i = ip6_hdrs->Size(); + if ( i > 0 ) + return (*ip6_hdrs)[i - 1]->NextHdr(); - return IPPROTO_NONE; - } + return IPPROTO_NONE; + } - /** - * Returns the IPv4 Time to Live or IPv6 Hop Limit field. - */ - unsigned char TTL() const { return ip4 ? ip4->ip_ttl : ip6->ip6_hlim; } + /** + * Returns the IPv4 Time to Live or IPv6 Hop Limit field. + */ + unsigned char TTL() const { return ip4 ? ip4->ip_ttl : ip6->ip6_hlim; } - /** - * Returns whether the IP header indicates this packet is a fragment. - */ - bool IsFragment() const - { - return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment(); - } + /** + * Returns whether the IP header indicates this packet is a fragment. + */ + bool IsFragment() const { return ip4 ? (ntohs(ip4->ip_off) & 0x3fff) != 0 : ip6_hdrs->IsFragment(); } - /** - * Returns the fragment packet's offset in relation to the original - * packet in bytes. - */ - uint16_t FragOffset() const - { - return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset(); - } + /** + * Returns the fragment packet's offset in relation to the original + * packet in bytes. + */ + uint16_t FragOffset() const { return ip4 ? (ntohs(ip4->ip_off) & 0x1fff) * 8 : ip6_hdrs->FragOffset(); } - /** - * Returns the fragment packet's identification field. - */ - uint32_t ID() const { return ip4 ? ntohs(ip4->ip_id) : ip6_hdrs->ID(); } + /** + * Returns the fragment packet's identification field. + */ + uint32_t ID() const { return ip4 ? ntohs(ip4->ip_id) : ip6_hdrs->ID(); } - /** - * Returns whether a fragment packet's "More Fragments" field is set. - */ - int MF() const { return ip4 ? (ntohs(ip4->ip_off) & 0x2000) != 0 : ip6_hdrs->MF(); } + /** + * Returns whether a fragment packet's "More Fragments" field is set. + */ + int MF() const { return ip4 ? (ntohs(ip4->ip_off) & 0x2000) != 0 : ip6_hdrs->MF(); } - /** - * Returns whether a fragment packet's "Don't Fragment" field is set. - * Note that IPv6 has no such field. - */ - int DF() const { return ip4 ? ((ntohs(ip4->ip_off) & 0x4000) != 0) : 0; } + /** + * Returns whether a fragment packet's "Don't Fragment" field is set. + * Note that IPv6 has no such field. + */ + int DF() const { return ip4 ? ((ntohs(ip4->ip_off) & 0x4000) != 0) : 0; } - /** - * Returns value of an IPv6 header's flow label field or 0 if it's IPv4. - */ - uint32_t FlowLabel() const { return ip4 ? 0 : (ntohl(ip6->ip6_flow) & 0x000fffff); } + /** + * Returns value of an IPv6 header's flow label field or 0 if it's IPv4. + */ + uint32_t FlowLabel() const { return ip4 ? 0 : (ntohl(ip6->ip6_flow) & 0x000fffff); } - /** - * Returns number of IP headers in packet (includes IPv6 extension headers). - */ - size_t NumHeaders() const { return ip4 ? 1 : ip6_hdrs->Size(); } + /** + * Returns number of IP headers in packet (includes IPv6 extension headers). + */ + size_t NumHeaders() const { return ip4 ? 1 : ip6_hdrs->Size(); } - /** - * Returns an ip_hdr or ip6_hdr_chain RecordVal. - */ - RecordValPtr ToIPHdrVal() const; + /** + * Returns an ip_hdr or ip6_hdr_chain RecordVal. + */ + RecordValPtr ToIPHdrVal() const; - /** - * Returns a pkt_hdr RecordVal, which includes not only the IP header, but - * also upper-layer (tcp/udp/icmp) headers. - */ - RecordValPtr ToPktHdrVal() const; + /** + * Returns a pkt_hdr RecordVal, which includes not only the IP header, but + * also upper-layer (tcp/udp/icmp) headers. + */ + RecordValPtr ToPktHdrVal() const; - /** - * Same as above, but simply add our values into the record at the - * specified starting index. - */ - RecordValPtr ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const; + /** + * Same as above, but simply add our values into the record at the + * specified starting index. + */ + RecordValPtr ToPktHdrVal(RecordValPtr pkt_hdr, int sindex) const; - bool Reassembled() const { return reassembled; } + bool Reassembled() const { return reassembled; } private: - const struct ip* ip4 = nullptr; - const struct ip6_hdr* ip6 = nullptr; - const IPv6_Hdr_Chain* ip6_hdrs = nullptr; - bool del = false; - bool reassembled = false; - }; + const struct ip* ip4 = nullptr; + const struct ip6_hdr* ip6 = nullptr; + const IPv6_Hdr_Chain* ip6_hdrs = nullptr; + bool del = false; + bool reassembled = false; +}; - } // namespace zeek +} // namespace zeek diff --git a/src/IPAddr.cc b/src/IPAddr.cc index 3b4e037c97..7e162c3b88 100644 --- a/src/IPAddr.cc +++ b/src/IPAddr.cc @@ -13,415 +13,365 @@ #include "zeek/ZeekString.h" #include "zeek/analyzer/Manager.h" -namespace zeek - { +namespace zeek { const IPAddr IPAddr::v4_unspecified = IPAddr(in4_addr{}); const IPAddr IPAddr::v6_unspecified = IPAddr(); -namespace detail - { - -ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, - TransportProto t, bool one_way) - { - Init(src, dst, src_port, dst_port, t, one_way); - } - -ConnKey::ConnKey(const ConnTuple& id) - { - Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way); - } - -ConnKey& ConnKey::operator=(const ConnKey& rhs) - { - if ( this == &rhs ) - return *this; - - // Because of padding in the object, this needs to memset to clear out - // the extra memory used by padding. Otherwise, the session key stuff - // doesn't work quite right. - memset(this, 0, sizeof(ConnKey)); - - memcpy(&ip1, &rhs.ip1, sizeof(in6_addr)); - memcpy(&ip2, &rhs.ip2, sizeof(in6_addr)); - port1 = rhs.port1; - port2 = rhs.port2; - transport = rhs.transport; - valid = rhs.valid; - - return *this; - } - -ConnKey::ConnKey(Val* v) - { - const auto& vt = v->GetType(); - if ( ! IsRecord(vt->Tag()) ) - { - valid = false; - return; - } - - RecordType* vr = vt->AsRecordType(); - auto vl = v->As(); - - int orig_h, orig_p; // indices into record's value list - int resp_h, resp_p; - - if ( vr == id::conn_id ) - { - orig_h = 0; - orig_p = 1; - resp_h = 2; - resp_p = 3; - } - else - { - // While it's not a conn_id, it may have equivalent fields. - orig_h = vr->FieldOffset("orig_h"); - resp_h = vr->FieldOffset("resp_h"); - orig_p = vr->FieldOffset("orig_p"); - resp_p = vr->FieldOffset("resp_p"); - - if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 ) - { - valid = false; - return; - } - - // ### we ought to check that the fields have the right - // types, too. - } - - const IPAddr& orig_addr = vl->GetFieldAs(orig_h); - const IPAddr& resp_addr = vl->GetFieldAs(resp_h); - - auto orig_portv = vl->GetFieldAs(orig_p); - auto resp_portv = vl->GetFieldAs(resp_p); - - Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), - htons((unsigned short)resp_portv->Port()), orig_portv->PortType(), false); - } - -void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, - TransportProto t, bool one_way) - { - // Because of padding in the object, this needs to memset to clear out - // the extra memory used by padding. Otherwise, the session key stuff - // doesn't work quite right. - memset(this, 0, sizeof(ConnKey)); - - // Lookup up connection based on canonical ordering, which is - // the smaller of and - // followed by the other. - if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) - { - ip1 = src.in6; - ip2 = dst.in6; - port1 = src_port; - port2 = dst_port; - } - else - { - ip1 = dst.in6; - ip2 = src.in6; - port1 = dst_port; - port2 = src_port; - } - - transport = t; - valid = true; - } - - } // namespace detail - -IPAddr::IPAddr(const String& s) - { - Init(s.CheckString()); - } - -std::unique_ptr IPAddr::MakeHashKey() const - { - return std::make_unique((void*)in6.s6_addr, sizeof(in6.s6_addr)); - } - -static inline uint32_t bit_mask32(int bottom_bits) - { - if ( bottom_bits >= 32 ) - return 0xffffffff; - - return (((uint32_t)1) << bottom_bits) - 1; - } - -void IPAddr::Mask(int top_bits_to_keep) - { - if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 ) - { - reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep); - return; - } - - uint32_t mask_bits[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff}; - std::ldiv_t res = std::ldiv(top_bits_to_keep, 32); - - if ( res.quot < 4 ) - mask_bits[res.quot] = htonl(mask_bits[res.quot] & ~bit_mask32(32 - res.rem)); - - for ( unsigned int i = res.quot + 1; i < 4; ++i ) - mask_bits[i] = 0; - - uint32_t* p = reinterpret_cast(in6.s6_addr); - - for ( unsigned int i = 0; i < 4; ++i ) - p[i] &= mask_bits[i]; - } - -void IPAddr::ReverseMask(int top_bits_to_chop) - { - if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 ) - { - reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop); - return; - } - - uint32_t mask_bits[4] = {0, 0, 0, 0}; - std::ldiv_t res = std::ldiv(top_bits_to_chop, 32); - - if ( res.quot < 4 ) - mask_bits[res.quot] = htonl(bit_mask32(32 - res.rem)); - - for ( unsigned int i = res.quot + 1; i < 4; ++i ) - mask_bits[i] = 0xffffffff; - - uint32_t* p = reinterpret_cast(in6.s6_addr); - - for ( unsigned int i = 0; i < 4; ++i ) - p[i] &= mask_bits[i]; - } - -bool IPAddr::ConvertString(const char* s, in6_addr* result) - { - for ( auto p = s; *p; ++p ) - if ( *p == ':' ) - // IPv6 - return (inet_pton(AF_INET6, s, result->s6_addr) == 1); - - // IPv4 - // Parse the address directly instead of using inet_pton since - // some platforms have more sensitive implementations than others - // that can't e.g. handle leading zeroes. - int a[4]; - int n = 0; - int match_count = sscanf(s, "%d.%d.%d.%d%n", a + 0, a + 1, a + 2, a + 3, &n); - - if ( match_count != 4 ) - return false; - - if ( s[n] != '\0' ) - return false; - - for ( auto i = 0; i < 4; ++i ) - if ( a[i] < 0 || a[i] > 255 ) - return false; - - uint32_t addr = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]; - addr = htonl(addr); - memcpy(result->s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); - memcpy(&result->s6_addr[12], &addr, sizeof(uint32_t)); - return true; - } - -void IPAddr::Init(const char* s) - { - if ( ! ConvertString(s, &in6) ) - { - reporter->Error("Bad IP address: %s", s); - memset(in6.s6_addr, 0, sizeof(in6.s6_addr)); - } - } - -std::string IPAddr::AsString() const - { - if ( GetFamily() == IPv4 ) - { - char s[INET_ADDRSTRLEN]; - - if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) ) - return "> 24) & 0xff; - uint32_t a2 = (a >> 16) & 0xff; - uint32_t a1 = (a >> 8) & 0xff; - uint32_t a0 = a & 0xff; - snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3); - return buf; - } - else - { - static const char hex_digit[] = "0123456789abcdef"; - std::string ptr_name("ip6.arpa"); - uint32_t* p = (uint32_t*)in6.s6_addr; - - for ( unsigned int i = 0; i < 4; ++i ) - { - uint32_t a = ntohl(p[i]); - for ( unsigned int j = 1; j <= 8; ++j ) - { - ptr_name.insert(0, 1, '.'); - ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]); - } - } - - return ptr_name; - } - } - -IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) - { - if ( length > 32 ) - { - reporter->Error("Bad in4_addr IPPrefix length : %d", length); - this->length = 0; - } - - prefix.Mask(this->length); - } - -IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) - { - if ( length > 128 ) - { - reporter->Error("Bad in6_addr IPPrefix length : %d", length); - this->length = 0; - } - - prefix.Mask(this->length); - } - -bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const - { - if ( GetFamily() == IPv4 && ! len_is_v6_relative ) - { - if ( length > 32 ) - return false; - } - - else - { - if ( length > 128 ) - return false; - } - - return true; - } - -IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr) - { - if ( prefix.CheckPrefixLength(length, len_is_v6_relative) ) - { - if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative ) - this->length = length + 96; - else - this->length = length; - } - else - { - auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6"; - reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length); - this->length = 0; - } - - prefix.Mask(this->length); - } - -std::string IPPrefix::AsString() const - { - char l[16]; - - if ( prefix.GetFamily() == IPv4 ) - modp_uitoa10(length - 96, l); - else - modp_uitoa10(length, l); - - return prefix.AsString() + "/" + l; - } - -std::unique_ptr IPPrefix::MakeHashKey() const - { - struct - { - in6_addr ip; - uint32_t len; - } key; - - key.ip = prefix.in6; - key.len = Length(); - - return std::make_unique(&key, sizeof(key)); - } - -bool IPPrefix::ConvertString(const char* text, IPPrefix* result) - { - std::string s(text); - size_t slash_loc = s.find('/'); - - if ( slash_loc == std::string::npos ) - return false; - - auto ip_str = s.substr(0, slash_loc); - auto len = atoi(s.substr(slash_loc + 1).data()); - - in6_addr tmp; - - if ( ! IPAddr::ConvertString(ip_str.data(), &tmp) ) - return false; - - auto ip = IPAddr(tmp); - - if ( ! ip.CheckPrefixLength(len) ) - return false; - - *result = IPPrefix(ip, len); - return true; - } - - } // namespace zeek +namespace detail { + +ConnKey::ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t, + bool one_way) { + Init(src, dst, src_port, dst_port, t, one_way); +} + +ConnKey::ConnKey(const ConnTuple& id) { + Init(id.src_addr, id.dst_addr, id.src_port, id.dst_port, id.proto, id.is_one_way); +} + +ConnKey& ConnKey::operator=(const ConnKey& rhs) { + if ( this == &rhs ) + return *this; + + // Because of padding in the object, this needs to memset to clear out + // the extra memory used by padding. Otherwise, the session key stuff + // doesn't work quite right. + memset(this, 0, sizeof(ConnKey)); + + memcpy(&ip1, &rhs.ip1, sizeof(in6_addr)); + memcpy(&ip2, &rhs.ip2, sizeof(in6_addr)); + port1 = rhs.port1; + port2 = rhs.port2; + transport = rhs.transport; + valid = rhs.valid; + + return *this; +} + +ConnKey::ConnKey(Val* v) { + const auto& vt = v->GetType(); + if ( ! IsRecord(vt->Tag()) ) { + valid = false; + return; + } + + RecordType* vr = vt->AsRecordType(); + auto vl = v->As(); + + int orig_h, orig_p; // indices into record's value list + int resp_h, resp_p; + + if ( vr == id::conn_id ) { + orig_h = 0; + orig_p = 1; + resp_h = 2; + resp_p = 3; + } + else { + // While it's not a conn_id, it may have equivalent fields. + orig_h = vr->FieldOffset("orig_h"); + resp_h = vr->FieldOffset("resp_h"); + orig_p = vr->FieldOffset("orig_p"); + resp_p = vr->FieldOffset("resp_p"); + + if ( orig_h < 0 || resp_h < 0 || orig_p < 0 || resp_p < 0 ) { + valid = false; + return; + } + + // ### we ought to check that the fields have the right + // types, too. + } + + const IPAddr& orig_addr = vl->GetFieldAs(orig_h); + const IPAddr& resp_addr = vl->GetFieldAs(resp_h); + + auto orig_portv = vl->GetFieldAs(orig_p); + auto resp_portv = vl->GetFieldAs(resp_p); + + Init(orig_addr, resp_addr, htons((unsigned short)orig_portv->Port()), htons((unsigned short)resp_portv->Port()), + orig_portv->PortType(), false); +} + +void ConnKey::Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t, + bool one_way) { + // Because of padding in the object, this needs to memset to clear out + // the extra memory used by padding. Otherwise, the session key stuff + // doesn't work quite right. + memset(this, 0, sizeof(ConnKey)); + + // Lookup up connection based on canonical ordering, which is + // the smaller of and + // followed by the other. + if ( one_way || addr_port_canon_lt(src, src_port, dst, dst_port) ) { + ip1 = src.in6; + ip2 = dst.in6; + port1 = src_port; + port2 = dst_port; + } + else { + ip1 = dst.in6; + ip2 = src.in6; + port1 = dst_port; + port2 = src_port; + } + + transport = t; + valid = true; +} + +} // namespace detail + +IPAddr::IPAddr(const String& s) { Init(s.CheckString()); } + +std::unique_ptr IPAddr::MakeHashKey() const { + return std::make_unique((void*)in6.s6_addr, sizeof(in6.s6_addr)); +} + +static inline uint32_t bit_mask32(int bottom_bits) { + if ( bottom_bits >= 32 ) + return 0xffffffff; + + return (((uint32_t)1) << bottom_bits) - 1; +} + +void IPAddr::Mask(int top_bits_to_keep) { + if ( top_bits_to_keep < 0 || top_bits_to_keep > 128 ) { + reporter->Error("Bad IPAddr::Mask value %d", top_bits_to_keep); + return; + } + + uint32_t mask_bits[4] = {0xffffffff, 0xffffffff, 0xffffffff, 0xffffffff}; + std::ldiv_t res = std::ldiv(top_bits_to_keep, 32); + + if ( res.quot < 4 ) + mask_bits[res.quot] = htonl(mask_bits[res.quot] & ~bit_mask32(32 - res.rem)); + + for ( unsigned int i = res.quot + 1; i < 4; ++i ) + mask_bits[i] = 0; + + uint32_t* p = reinterpret_cast(in6.s6_addr); + + for ( unsigned int i = 0; i < 4; ++i ) + p[i] &= mask_bits[i]; +} + +void IPAddr::ReverseMask(int top_bits_to_chop) { + if ( top_bits_to_chop < 0 || top_bits_to_chop > 128 ) { + reporter->Error("Bad IPAddr::ReverseMask value %d", top_bits_to_chop); + return; + } + + uint32_t mask_bits[4] = {0, 0, 0, 0}; + std::ldiv_t res = std::ldiv(top_bits_to_chop, 32); + + if ( res.quot < 4 ) + mask_bits[res.quot] = htonl(bit_mask32(32 - res.rem)); + + for ( unsigned int i = res.quot + 1; i < 4; ++i ) + mask_bits[i] = 0xffffffff; + + uint32_t* p = reinterpret_cast(in6.s6_addr); + + for ( unsigned int i = 0; i < 4; ++i ) + p[i] &= mask_bits[i]; +} + +bool IPAddr::ConvertString(const char* s, in6_addr* result) { + for ( auto p = s; *p; ++p ) + if ( *p == ':' ) + // IPv6 + return (inet_pton(AF_INET6, s, result->s6_addr) == 1); + + // IPv4 + // Parse the address directly instead of using inet_pton since + // some platforms have more sensitive implementations than others + // that can't e.g. handle leading zeroes. + int a[4]; + int n = 0; + int match_count = sscanf(s, "%d.%d.%d.%d%n", a + 0, a + 1, a + 2, a + 3, &n); + + if ( match_count != 4 ) + return false; + + if ( s[n] != '\0' ) + return false; + + for ( auto i = 0; i < 4; ++i ) + if ( a[i] < 0 || a[i] > 255 ) + return false; + + uint32_t addr = (a[0] << 24) | (a[1] << 16) | (a[2] << 8) | a[3]; + addr = htonl(addr); + memcpy(result->s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); + memcpy(&result->s6_addr[12], &addr, sizeof(uint32_t)); + return true; +} + +void IPAddr::Init(const char* s) { + if ( ! ConvertString(s, &in6) ) { + reporter->Error("Bad IP address: %s", s); + memset(in6.s6_addr, 0, sizeof(in6.s6_addr)); + } +} + +std::string IPAddr::AsString() const { + if ( GetFamily() == IPv4 ) { + char s[INET_ADDRSTRLEN]; + + if ( ! zeek_inet_ntop(AF_INET, &in6.s6_addr[12], s, INET_ADDRSTRLEN) ) + return "> 24) & 0xff; + uint32_t a2 = (a >> 16) & 0xff; + uint32_t a1 = (a >> 8) & 0xff; + uint32_t a0 = a & 0xff; + snprintf(buf, sizeof(buf), "%u.%u.%u.%u.in-addr.arpa", a0, a1, a2, a3); + return buf; + } + else { + static const char hex_digit[] = "0123456789abcdef"; + std::string ptr_name("ip6.arpa"); + uint32_t* p = (uint32_t*)in6.s6_addr; + + for ( unsigned int i = 0; i < 4; ++i ) { + uint32_t a = ntohl(p[i]); + for ( unsigned int j = 1; j <= 8; ++j ) { + ptr_name.insert(0, 1, '.'); + ptr_name.insert(0, 1, hex_digit[(a >> (32 - j * 4)) & 0x0f]); + } + } + + return ptr_name; + } +} + +IPPrefix::IPPrefix(const in4_addr& in4, uint8_t length) : prefix(in4), length(96 + length) { + if ( length > 32 ) { + reporter->Error("Bad in4_addr IPPrefix length : %d", length); + this->length = 0; + } + + prefix.Mask(this->length); +} + +IPPrefix::IPPrefix(const in6_addr& in6, uint8_t length) : prefix(in6), length(length) { + if ( length > 128 ) { + reporter->Error("Bad in6_addr IPPrefix length : %d", length); + this->length = 0; + } + + prefix.Mask(this->length); +} + +bool IPAddr::CheckPrefixLength(uint8_t length, bool len_is_v6_relative) const { + if ( GetFamily() == IPv4 && ! len_is_v6_relative ) { + if ( length > 32 ) + return false; + } + + else { + if ( length > 128 ) + return false; + } + + return true; +} + +IPPrefix::IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative) : prefix(addr) { + if ( prefix.CheckPrefixLength(length, len_is_v6_relative) ) { + if ( prefix.GetFamily() == IPv4 && ! len_is_v6_relative ) + this->length = length + 96; + else + this->length = length; + } + else { + auto vstr = prefix.GetFamily() == IPv4 ? "v4" : "v6"; + reporter->Error("Bad IPAddr(%s) IPPrefix length : %d", vstr, length); + this->length = 0; + } + + prefix.Mask(this->length); +} + +std::string IPPrefix::AsString() const { + char l[16]; + + if ( prefix.GetFamily() == IPv4 ) + modp_uitoa10(length - 96, l); + else + modp_uitoa10(length, l); + + return prefix.AsString() + "/" + l; +} + +std::unique_ptr IPPrefix::MakeHashKey() const { + struct { + in6_addr ip; + uint32_t len; + } key; + + key.ip = prefix.in6; + key.len = Length(); + + return std::make_unique(&key, sizeof(key)); +} + +bool IPPrefix::ConvertString(const char* text, IPPrefix* result) { + std::string s(text); + size_t slash_loc = s.find('/'); + + if ( slash_loc == std::string::npos ) + return false; + + auto ip_str = s.substr(0, slash_loc); + auto len = atoi(s.substr(slash_loc + 1).data()); + + in6_addr tmp; + + if ( ! IPAddr::ConvertString(ip_str.data(), &tmp) ) + return false; + + auto ip = IPAddr(tmp); + + if ( ! ip.CheckPrefixLength(len) ) + return false; + + *result = IPPrefix(ip, len); + return true; +} + +} // namespace zeek diff --git a/src/IPAddr.h b/src/IPAddr.h index a5d37a0e45..4395954eab 100644 --- a/src/IPAddr.h +++ b/src/IPAddr.h @@ -12,682 +12,624 @@ using in4_addr = in_addr; -namespace zeek - { +namespace zeek { class String; struct ConnTuple; class Val; -namespace detail - { +namespace detail { class HashKey; -class ConnKey - { +class ConnKey { public: - in6_addr ip1; - in6_addr ip2; - uint16_t port1 = 0; - uint16_t port2 = 0; - TransportProto transport = TRANSPORT_UNKNOWN; - bool valid = true; + in6_addr ip1; + in6_addr ip2; + uint16_t port1 = 0; + uint16_t port2 = 0; + TransportProto transport = TRANSPORT_UNKNOWN; + bool valid = true; - ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, - TransportProto t, bool one_way); - ConnKey(const ConnTuple& conn); - ConnKey(const ConnKey& rhs) { *this = rhs; } - ConnKey(Val* v); + ConnKey(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t, bool one_way); + ConnKey(const ConnTuple& conn); + ConnKey(const ConnKey& rhs) { *this = rhs; } + ConnKey(Val* v); - bool operator<(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) < 0; } - bool operator<=(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) <= 0; } - bool operator==(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) == 0; } - bool operator!=(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) != 0; } - bool operator>=(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) >= 0; } - bool operator>(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) > 0; } + bool operator<(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) < 0; } + bool operator<=(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) <= 0; } + bool operator==(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) == 0; } + bool operator!=(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) != 0; } + bool operator>=(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) >= 0; } + bool operator>(const ConnKey& rhs) const { return memcmp(this, &rhs, sizeof(ConnKey)) > 0; } - ConnKey& operator=(const ConnKey& rhs); + ConnKey& operator=(const ConnKey& rhs); private: - void Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, - TransportProto t, bool one_way); - }; + void Init(const IPAddr& src, const IPAddr& dst, uint16_t src_port, uint16_t dst_port, TransportProto t, + bool one_way); +}; - } // namespace detail +} // namespace detail /** * Class storing both IPv4 and IPv6 addresses. */ -class IPAddr - { +class IPAddr { public: - /** - * Address family. - */ - using Family = IPFamily; + /** + * Address family. + */ + using Family = IPFamily; - /** - * Byte order. - */ - enum ByteOrder - { - Host, - Network - }; + /** + * Byte order. + */ + enum ByteOrder { Host, Network }; - /** - * Constructs the unspecified IPv6 address (all 128 bits zeroed). - */ - IPAddr() { memset(in6.s6_addr, 0, sizeof(in6.s6_addr)); } + /** + * Constructs the unspecified IPv6 address (all 128 bits zeroed). + */ + IPAddr() { memset(in6.s6_addr, 0, sizeof(in6.s6_addr)); } - /** - * Constructs an address instance from an IPv4 address. - * - * @param in6 The IPv6 address. - */ - explicit IPAddr(const in4_addr& in4) - { - memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); - memcpy(&in6.s6_addr[12], &in4.s_addr, sizeof(in4.s_addr)); - } + /** + * Constructs an address instance from an IPv4 address. + * + * @param in6 The IPv6 address. + */ + explicit IPAddr(const in4_addr& in4) { + memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); + memcpy(&in6.s6_addr[12], &in4.s_addr, sizeof(in4.s_addr)); + } - /** - * Constructs an address instance from an IPv6 address. - * - * @param in6 The IPv6 address. - */ - explicit IPAddr(const in6_addr& arg_in6) : in6(arg_in6) { } + /** + * Constructs an address instance from an IPv6 address. + * + * @param in6 The IPv6 address. + */ + explicit IPAddr(const in6_addr& arg_in6) : in6(arg_in6) {} - /** - * Constructs an address instance from a string representation. - * - * @param s String containing an IP address as either a dotted IPv4 - * address or a hex IPv6 address. - */ - IPAddr(const std::string& s) { Init(s.data()); } + /** + * Constructs an address instance from a string representation. + * + * @param s String containing an IP address as either a dotted IPv4 + * address or a hex IPv6 address. + */ + IPAddr(const std::string& s) { Init(s.data()); } - /** - * Constructs an address instance from a string representation. - * - * @param s ASCIIZ string containing an IP address as either a - * dotted IPv4 address or a hex IPv6 address. - */ - IPAddr(const char* s) { Init(s); } + /** + * Constructs an address instance from a string representation. + * + * @param s ASCIIZ string containing an IP address as either a + * dotted IPv4 address or a hex IPv6 address. + */ + IPAddr(const char* s) { Init(s); } - /** - * Constructs an address instance from a string representation. - * - * @param s String containing an IP address as either a dotted IPv4 - * address or a hex IPv6 address. - */ - explicit IPAddr(const String& s); + /** + * Constructs an address instance from a string representation. + * + * @param s String containing an IP address as either a dotted IPv4 + * address or a hex IPv6 address. + */ + explicit IPAddr(const String& s); - /** - * Constructs an address instance from a raw byte representation. - * - * @param family The address family. - * - * @param bytes A pointer to the raw byte representation. This must point - * to 4 bytes if \a family is IPv4, and to 16 bytes if \a family is - * IPv6. - * - * @param order Indicates whether the raw representation pointed to - * by \a bytes is stored in network or host order. - */ - IPAddr(Family family, const uint32_t* bytes, ByteOrder order); + /** + * Constructs an address instance from a raw byte representation. + * + * @param family The address family. + * + * @param bytes A pointer to the raw byte representation. This must point + * to 4 bytes if \a family is IPv4, and to 16 bytes if \a family is + * IPv6. + * + * @param order Indicates whether the raw representation pointed to + * by \a bytes is stored in network or host order. + */ + IPAddr(Family family, const uint32_t* bytes, ByteOrder order); - /** - * Copy constructor. - */ - IPAddr(const IPAddr& other) : in6(other.in6){}; + /** + * Copy constructor. + */ + IPAddr(const IPAddr& other) : in6(other.in6){}; - /** - * Destructor. - */ - ~IPAddr() = default; + /** + * Destructor. + */ + ~IPAddr() = default; - /** - * Returns the address' family. - */ - Family GetFamily() const - { - if ( memcmp(in6.s6_addr, v4_mapped_prefix, 12) == 0 ) - return IPv4; + /** + * Returns the address' family. + */ + Family GetFamily() const { + if ( memcmp(in6.s6_addr, v4_mapped_prefix, 12) == 0 ) + return IPv4; - return IPv6; - } + return IPv6; + } - /** - * Returns true if the address represents a loopback device. - */ - bool IsLoopback() const; + /** + * Returns true if the address represents a loopback device. + */ + bool IsLoopback() const; - /** - * Returns true if the address represents a multicast address. - */ - bool IsMulticast() const - { - if ( GetFamily() == IPv4 ) - return in6.s6_addr[12] == 224; + /** + * Returns true if the address represents a multicast address. + */ + bool IsMulticast() const { + if ( GetFamily() == IPv4 ) + return in6.s6_addr[12] == 224; - return in6.s6_addr[0] == 0xff; - } + return in6.s6_addr[0] == 0xff; + } - /** - * Returns true if the address represents a broadcast address. - */ - bool IsBroadcast() const - { - if ( GetFamily() == IPv4 ) - return ((in6.s6_addr[12] == 0xff) && (in6.s6_addr[13] == 0xff) && - (in6.s6_addr[14] == 0xff) && (in6.s6_addr[15] == 0xff)); + /** + * Returns true if the address represents a broadcast address. + */ + bool IsBroadcast() const { + if ( GetFamily() == IPv4 ) + return ((in6.s6_addr[12] == 0xff) && (in6.s6_addr[13] == 0xff) && (in6.s6_addr[14] == 0xff) && + (in6.s6_addr[15] == 0xff)); - return false; - } + return false; + } - /** - * Retrieves the raw byte representation of the address. - * - * @param bytes The pointer to which \a bytes points will be set to - * the address of the raw representation in network-byte order. - * The return value indicates how many 32-bit words are valid starting at - * that address. The pointer will be valid as long as the address instance - * exists. - * - * @return The number of 32-bit words the raw representation uses. This - * will be 1 for an IPv4 address and 4 for an IPv6 address. - */ - int GetBytes(const uint32_t** bytes) const - { - if ( GetFamily() == IPv4 ) - { - *bytes = (uint32_t*)&in6.s6_addr[12]; - return 1; - } - else - { - *bytes = (uint32_t*)in6.s6_addr; - return 4; - } - } + /** + * Retrieves the raw byte representation of the address. + * + * @param bytes The pointer to which \a bytes points will be set to + * the address of the raw representation in network-byte order. + * The return value indicates how many 32-bit words are valid starting at + * that address. The pointer will be valid as long as the address instance + * exists. + * + * @return The number of 32-bit words the raw representation uses. This + * will be 1 for an IPv4 address and 4 for an IPv6 address. + */ + int GetBytes(const uint32_t** bytes) const { + if ( GetFamily() == IPv4 ) { + *bytes = (uint32_t*)&in6.s6_addr[12]; + return 1; + } + else { + *bytes = (uint32_t*)in6.s6_addr; + return 4; + } + } - /** - * Retrieves a copy of the IPv6 raw byte representation of the address. - * If the internal address is IPv4, then the copied bytes use the - * IPv4 to IPv6 address mapping to return a full 16 bytes. - * - * @param bytes The pointer to a memory location in which the - * raw bytes of the address are to be copied. - * - * @param order The byte-order in which the returned raw bytes are copied. - * The default is network order. - */ - void CopyIPv6(uint32_t* bytes, ByteOrder order = Network) const - { - memcpy(bytes, in6.s6_addr, sizeof(in6.s6_addr)); + /** + * Retrieves a copy of the IPv6 raw byte representation of the address. + * If the internal address is IPv4, then the copied bytes use the + * IPv4 to IPv6 address mapping to return a full 16 bytes. + * + * @param bytes The pointer to a memory location in which the + * raw bytes of the address are to be copied. + * + * @param order The byte-order in which the returned raw bytes are copied. + * The default is network order. + */ + void CopyIPv6(uint32_t* bytes, ByteOrder order = Network) const { + memcpy(bytes, in6.s6_addr, sizeof(in6.s6_addr)); - if ( order == Host ) - { - for ( unsigned int i = 0; i < 4; ++i ) - bytes[i] = ntohl(bytes[i]); - } - } + if ( order == Host ) { + for ( unsigned int i = 0; i < 4; ++i ) + bytes[i] = ntohl(bytes[i]); + } + } - /** - * Retrieves a copy of the IPv6 raw byte representation of the address. - * @see CopyIPv6(uint32_t) - */ - void CopyIPv6(in6_addr* arg_in6) const - { - memcpy(arg_in6->s6_addr, in6.s6_addr, sizeof(in6.s6_addr)); - } + /** + * Retrieves a copy of the IPv6 raw byte representation of the address. + * @see CopyIPv6(uint32_t) + */ + void CopyIPv6(in6_addr* arg_in6) const { memcpy(arg_in6->s6_addr, in6.s6_addr, sizeof(in6.s6_addr)); } - /** - * Retrieves a copy of the IPv4 raw byte representation of the address. - * The caller should verify the address is of the IPv4 family type - * beforehand. @see GetFamily(). - * - * @param in4 The pointer to a memory location in which the raw bytes - * of the address are to be copied in network byte-order. - */ - void CopyIPv4(in4_addr* in4) const - { - memcpy(&in4->s_addr, &in6.s6_addr[12], sizeof(in4->s_addr)); - } + /** + * Retrieves a copy of the IPv4 raw byte representation of the address. + * The caller should verify the address is of the IPv4 family type + * beforehand. @see GetFamily(). + * + * @param in4 The pointer to a memory location in which the raw bytes + * of the address are to be copied in network byte-order. + */ + void CopyIPv4(in4_addr* in4) const { memcpy(&in4->s_addr, &in6.s6_addr[12], sizeof(in4->s_addr)); } - /** - * Returns a key that can be used to lookup the IP Address in a hash table. - */ - std::unique_ptr MakeHashKey() const; + /** + * Returns a key that can be used to lookup the IP Address in a hash table. + */ + std::unique_ptr MakeHashKey() const; - /** - * Masks out lower bits of the address. - * - * @param top_bits_to_keep The number of bits \a not to mask out, - * counting from the highest order bit. The value is always - * interpreted relative to the IPv6 bit width, even if the address - * is IPv4. That means if compute ``192.168.1.2/16``, you need to - * pass in 112 (i.e., 96 + 16). The value must be in the range from - * 0 to 128. - */ - void Mask(int top_bits_to_keep); + /** + * Masks out lower bits of the address. + * + * @param top_bits_to_keep The number of bits \a not to mask out, + * counting from the highest order bit. The value is always + * interpreted relative to the IPv6 bit width, even if the address + * is IPv4. That means if compute ``192.168.1.2/16``, you need to + * pass in 112 (i.e., 96 + 16). The value must be in the range from + * 0 to 128. + */ + void Mask(int top_bits_to_keep); - /** - * Masks out top bits of the address. - * - * @param top_bits_to_chop The number of bits to mask out, counting - * from the highest order bit. The value is always interpreted relative - * to the IPv6 bit width, even if the address is IPv4. So to mask out - * the first 16 bits of an IPv4 address, pass in 112 (i.e., 96 + 16). - * The value must be in the range from 0 to 128. - */ - void ReverseMask(int top_bits_to_chop); + /** + * Masks out top bits of the address. + * + * @param top_bits_to_chop The number of bits to mask out, counting + * from the highest order bit. The value is always interpreted relative + * to the IPv6 bit width, even if the address is IPv4. So to mask out + * the first 16 bits of an IPv4 address, pass in 112 (i.e., 96 + 16). + * The value must be in the range from 0 to 128. + */ + void ReverseMask(int top_bits_to_chop); - /** - * Assignment operator. - */ - IPAddr& operator=(const IPAddr& other) - { - // No self-assignment check here because it's correct without it and - // makes the common case faster. - in6 = other.in6; - return *this; - } + /** + * Assignment operator. + */ + IPAddr& operator=(const IPAddr& other) { + // No self-assignment check here because it's correct without it and + // makes the common case faster. + in6 = other.in6; + return *this; + } - /** - * Bitwise OR operator returns the IP address resulting from the bitwise - * OR operation on the raw bytes of this address with another. - */ - IPAddr operator|(const IPAddr& other) - { - in6_addr result; - for ( int i = 0; i < 16; ++i ) - result.s6_addr[i] = this->in6.s6_addr[i] | other.in6.s6_addr[i]; + /** + * Bitwise OR operator returns the IP address resulting from the bitwise + * OR operation on the raw bytes of this address with another. + */ + IPAddr operator|(const IPAddr& other) { + in6_addr result; + for ( int i = 0; i < 16; ++i ) + result.s6_addr[i] = this->in6.s6_addr[i] | other.in6.s6_addr[i]; - return IPAddr(result); - } + return IPAddr(result); + } - /** - * Returns a string representation of the address. IPv4 addresses - * will be returned in dotted representation, IPv6 addresses in - * compressed hex. - */ - std::string AsString() const; + /** + * Returns a string representation of the address. IPv4 addresses + * will be returned in dotted representation, IPv6 addresses in + * compressed hex. + */ + std::string AsString() const; - /** - * Returns a string representation of the address suitable for inclusion - * in an URI. For IPv4 addresses, this is the same as AsString(), but - * IPv6 addresses are encased in square brackets. - */ - std::string AsURIString() const - { - if ( GetFamily() == IPv4 ) - return AsString(); + /** + * Returns a string representation of the address suitable for inclusion + * in an URI. For IPv4 addresses, this is the same as AsString(), but + * IPv6 addresses are encased in square brackets. + */ + std::string AsURIString() const { + if ( GetFamily() == IPv4 ) + return AsString(); - return std::string("[") + AsString() + "]"; - } + return std::string("[") + AsString() + "]"; + } - /** - * Returns a host-order, plain hex string representation of the address. - */ - std::string AsHexString() const; + /** + * Returns a host-order, plain hex string representation of the address. + */ + std::string AsHexString() const; - /** - * Returns a string representation of the address. This returns the - * same as AsString(). - */ - operator std::string() const { return AsString(); } + /** + * Returns a string representation of the address. This returns the + * same as AsString(). + */ + operator std::string() const { return AsString(); } - /** - * Returns a reverse pointer name associated with the IP address. - * For example, 192.168.0.1's reverse pointer is 1.0.168.192.in-addr.arpa. - */ - std::string PtrName() const; + /** + * Returns a reverse pointer name associated with the IP address. + * For example, 192.168.0.1's reverse pointer is 1.0.168.192.in-addr.arpa. + */ + std::string PtrName() const; - /** - * Comparison operator for IP address. - */ - friend bool operator==(const IPAddr& addr1, const IPAddr& addr2) - { - return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) == 0; - } + /** + * Comparison operator for IP address. + */ + friend bool operator==(const IPAddr& addr1, const IPAddr& addr2) { + return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) == 0; + } - friend bool operator!=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 == addr2); } + friend bool operator!=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 == addr2); } - /** - * Comparison operator IP addresses. This defines a well-defined order for - * IP addresses. However, the order does not necessarily correspond to - * their numerical values. - */ - friend bool operator<(const IPAddr& addr1, const IPAddr& addr2) - { - return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) < 0; - } + /** + * Comparison operator IP addresses. This defines a well-defined order for + * IP addresses. However, the order does not necessarily correspond to + * their numerical values. + */ + friend bool operator<(const IPAddr& addr1, const IPAddr& addr2) { + return memcmp(&addr1.in6, &addr2.in6, sizeof(in6_addr)) < 0; + } - friend bool operator<=(const IPAddr& addr1, const IPAddr& addr2) - { - return addr1 < addr2 || addr1 == addr2; - } + friend bool operator<=(const IPAddr& addr1, const IPAddr& addr2) { return addr1 < addr2 || addr1 == addr2; } - friend bool operator>=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 < addr2); } + friend bool operator>=(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 < addr2); } - friend bool operator>(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 <= addr2); } + friend bool operator>(const IPAddr& addr1, const IPAddr& addr2) { return ! (addr1 <= addr2); } - /** - * Converts the address into the type used internally by the - * inter-thread communication. - */ - void ConvertToThreadingValue(threading::Value::addr_t* v) const; + /** + * Converts the address into the type used internally by the + * inter-thread communication. + */ + void ConvertToThreadingValue(threading::Value::addr_t* v) const; - /** - * Check if an IP prefix length would be valid against this IP address. - * - * @param length the IP prefix length to check - * - * @param len_is_v6_relative whether the length is relative to the full - * IPv6 address length (e.g. since IPv4 addrs are internally stored - * in v4-to-v6-mapped format, this parameter disambiguates whether - * a the length is in the usual 32-bit space for IPv4 or the full - * 128-bit space of IPv6 address. - * - * @return whether the prefix length is valid. - */ - bool CheckPrefixLength(uint8_t length, bool len_is_v6_relative = false) const; + /** + * Check if an IP prefix length would be valid against this IP address. + * + * @param length the IP prefix length to check + * + * @param len_is_v6_relative whether the length is relative to the full + * IPv6 address length (e.g. since IPv4 addrs are internally stored + * in v4-to-v6-mapped format, this parameter disambiguates whether + * a the length is in the usual 32-bit space for IPv4 or the full + * 128-bit space of IPv6 address. + * + * @return whether the prefix length is valid. + */ + bool CheckPrefixLength(uint8_t length, bool len_is_v6_relative = false) const; - /** - * Converts an IPv4 or IPv6 string into a network address structure - * (IPv6 or v4-to-v6-mapping in network bytes order). - * - * @param s the IPv4 or IPv6 string to convert (ASCII, NUL-terminated). - * - * @param result buffer that the caller supplies to store the result. - * - * @return whether the conversion was successful. - */ - static bool ConvertString(const char* s, in6_addr* result); + /** + * Converts an IPv4 or IPv6 string into a network address structure + * (IPv6 or v4-to-v6-mapping in network bytes order). + * + * @param s the IPv4 or IPv6 string to convert (ASCII, NUL-terminated). + * + * @param result buffer that the caller supplies to store the result. + * + * @return whether the conversion was successful. + */ + static bool ConvertString(const char* s, in6_addr* result); - /** - * @param s the IPv4 or IPv6 string to convert (ASCII, NUL-terminated). - * - * @return whether the string is a valid IP address - */ - static bool IsValid(const char* s) - { - in6_addr tmp; - return ConvertString(s, &tmp); - } + /** + * @param s the IPv4 or IPv6 string to convert (ASCII, NUL-terminated). + * + * @return whether the string is a valid IP address + */ + static bool IsValid(const char* s) { + in6_addr tmp; + return ConvertString(s, &tmp); + } - /** - * Unspecified IPv4 addr, "0.0.0.0". - */ - static const IPAddr v4_unspecified; + /** + * Unspecified IPv4 addr, "0.0.0.0". + */ + static const IPAddr v4_unspecified; - /** - * Unspecified IPv6 addr, "::". - */ - static const IPAddr v6_unspecified; + /** + * Unspecified IPv6 addr, "::". + */ + static const IPAddr v6_unspecified; private: - friend class detail::ConnKey; - friend class IPPrefix; + friend class detail::ConnKey; + friend class IPPrefix; - /** - * Initializes an address instance from a string representation. - * - * @param s String containing an IP address as either a dotted IPv4 - * address or a hex IPv6 address (ASCII, NUL-terminated). - */ - void Init(const char* s); + /** + * Initializes an address instance from a string representation. + * + * @param s String containing an IP address as either a dotted IPv4 + * address or a hex IPv6 address (ASCII, NUL-terminated). + */ + void Init(const char* s); - in6_addr in6; // IPv6 or v4-to-v6-mapped address + in6_addr in6; // IPv6 or v4-to-v6-mapped address - // Top 96 bits of a v4-mapped-addr. - static constexpr uint8_t v4_mapped_prefix[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}; - }; + // Top 96 bits of a v4-mapped-addr. + static constexpr uint8_t v4_mapped_prefix[12] = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff}; +}; -inline IPAddr::IPAddr(Family family, const uint32_t* bytes, ByteOrder order) - { - if ( family == IPv4 ) - { - memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); - memcpy(&in6.s6_addr[12], bytes, sizeof(uint32_t)); +inline IPAddr::IPAddr(Family family, const uint32_t* bytes, ByteOrder order) { + if ( family == IPv4 ) { + memcpy(in6.s6_addr, v4_mapped_prefix, sizeof(v4_mapped_prefix)); + memcpy(&in6.s6_addr[12], bytes, sizeof(uint32_t)); - if ( order == Host ) - { - uint32_t* p = (uint32_t*)&in6.s6_addr[12]; - *p = htonl(*p); - } - } + if ( order == Host ) { + uint32_t* p = (uint32_t*)&in6.s6_addr[12]; + *p = htonl(*p); + } + } - else - { - memcpy(in6.s6_addr, bytes, sizeof(in6.s6_addr)); + else { + memcpy(in6.s6_addr, bytes, sizeof(in6.s6_addr)); - if ( order == Host ) - { - for ( unsigned int i = 0; i < 4; ++i ) - { - uint32_t* p = (uint32_t*)&in6.s6_addr[i * 4]; - *p = htonl(*p); - } - } - } - } + if ( order == Host ) { + for ( unsigned int i = 0; i < 4; ++i ) { + uint32_t* p = (uint32_t*)&in6.s6_addr[i * 4]; + *p = htonl(*p); + } + } + } +} -inline bool IPAddr::IsLoopback() const - { - if ( GetFamily() == IPv4 ) - return in6.s6_addr[12] == 127; +inline bool IPAddr::IsLoopback() const { + if ( GetFamily() == IPv4 ) + return in6.s6_addr[12] == 127; - else - return ((in6.s6_addr[0] == 0) && (in6.s6_addr[1] == 0) && (in6.s6_addr[2] == 0) && - (in6.s6_addr[3] == 0) && (in6.s6_addr[4] == 0) && (in6.s6_addr[5] == 0) && - (in6.s6_addr[6] == 0) && (in6.s6_addr[7] == 0) && (in6.s6_addr[8] == 0) && - (in6.s6_addr[9] == 0) && (in6.s6_addr[10] == 0) && (in6.s6_addr[11] == 0) && - (in6.s6_addr[12] == 0) && (in6.s6_addr[13] == 0) && (in6.s6_addr[14] == 0) && - (in6.s6_addr[15] == 1)); - } + else + return ((in6.s6_addr[0] == 0) && (in6.s6_addr[1] == 0) && (in6.s6_addr[2] == 0) && (in6.s6_addr[3] == 0) && + (in6.s6_addr[4] == 0) && (in6.s6_addr[5] == 0) && (in6.s6_addr[6] == 0) && (in6.s6_addr[7] == 0) && + (in6.s6_addr[8] == 0) && (in6.s6_addr[9] == 0) && (in6.s6_addr[10] == 0) && (in6.s6_addr[11] == 0) && + (in6.s6_addr[12] == 0) && (in6.s6_addr[13] == 0) && (in6.s6_addr[14] == 0) && (in6.s6_addr[15] == 1)); +} -inline void IPAddr::ConvertToThreadingValue(threading::Value::addr_t* v) const - { - v->family = GetFamily(); +inline void IPAddr::ConvertToThreadingValue(threading::Value::addr_t* v) const { + v->family = GetFamily(); - switch ( v->family ) - { - case IPv4: - CopyIPv4(&v->in.in4); - return; + switch ( v->family ) { + case IPv4: CopyIPv4(&v->in.in4); return; - case IPv6: - CopyIPv6(&v->in.in6); - return; - } - } + case IPv6: CopyIPv6(&v->in.in6); return; + } +} /** * Class storing both IPv4 and IPv6 prefixes * (i.e., \c 192.168.1.1/16 and \c FD00::/8. */ -class IPPrefix - { +class IPPrefix { public: - /** - * Constructs a prefix 0/0. - */ - IPPrefix() = default; + /** + * Constructs a prefix 0/0. + */ + IPPrefix() = default; - /** - * Constructs a prefix instance from an IPv4 address and a prefix - * length. - * - * @param in4 The IPv4 address. - * - * @param length The prefix length in the range from 0 to 32. - */ - IPPrefix(const in4_addr& in4, uint8_t length); + /** + * Constructs a prefix instance from an IPv4 address and a prefix + * length. + * + * @param in4 The IPv4 address. + * + * @param length The prefix length in the range from 0 to 32. + */ + IPPrefix(const in4_addr& in4, uint8_t length); - /** - * Constructs a prefix instance from an IPv6 address and a prefix - * length. - * - * @param in6 The IPv6 address. - * - * @param length The prefix length in the range from 0 to 128. - */ - IPPrefix(const in6_addr& in6, uint8_t length); + /** + * Constructs a prefix instance from an IPv6 address and a prefix + * length. + * + * @param in6 The IPv6 address. + * + * @param length The prefix length in the range from 0 to 128. + */ + IPPrefix(const in6_addr& in6, uint8_t length); - /** - * Constructs a prefix instance from an IPAddr object and prefix length. - * - * @param addr The IP address. - * - * @param length The prefix length in the range from 0 to 128 - * - * @param len_is_v6_relative Whether \a length is relative to the full - * 128 bits of an IPv6 address. If false and \a addr is an IPv4 - * address, then \a length is expected to range from 0 to 32. If true - * \a length is expected to range from 0 to 128 even if \a addr is IPv4, - * meaning that the mask is to apply to the IPv4-mapped-IPv6 representation. - */ - IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative = false); + /** + * Constructs a prefix instance from an IPAddr object and prefix length. + * + * @param addr The IP address. + * + * @param length The prefix length in the range from 0 to 128 + * + * @param len_is_v6_relative Whether \a length is relative to the full + * 128 bits of an IPv6 address. If false and \a addr is an IPv4 + * address, then \a length is expected to range from 0 to 32. If true + * \a length is expected to range from 0 to 128 even if \a addr is IPv4, + * meaning that the mask is to apply to the IPv4-mapped-IPv6 representation. + */ + IPPrefix(const IPAddr& addr, uint8_t length, bool len_is_v6_relative = false); - /** - * Copy constructor. - */ - IPPrefix(const IPPrefix& other) : prefix(other.prefix), length(other.length) { } + /** + * Copy constructor. + */ + IPPrefix(const IPPrefix& other) : prefix(other.prefix), length(other.length) {} - /** - * Destructor. - */ - ~IPPrefix() = default; + /** + * Destructor. + */ + ~IPPrefix() = default; - /** - * Returns the prefix in the form of an IP address. The address will - * have all bits not part of the prefixed set to zero. - */ - const IPAddr& Prefix() const { return prefix; } + /** + * Returns the prefix in the form of an IP address. The address will + * have all bits not part of the prefixed set to zero. + */ + const IPAddr& Prefix() const { return prefix; } - /** - * Returns the bit length of the prefix, relative to the 32 bits - * of an IPv4 prefix or relative to the 128 bits of an IPv6 prefix. - */ - uint8_t Length() const { return prefix.GetFamily() == IPv4 ? length - 96 : length; } + /** + * Returns the bit length of the prefix, relative to the 32 bits + * of an IPv4 prefix or relative to the 128 bits of an IPv6 prefix. + */ + uint8_t Length() const { return prefix.GetFamily() == IPv4 ? length - 96 : length; } - /** - * Returns the bit length of the prefix always relative to a full - * 128 bits of an IPv6 prefix (or IPv4 mapped to IPv6). - */ - uint8_t LengthIPv6() const { return length; } + /** + * Returns the bit length of the prefix always relative to a full + * 128 bits of an IPv6 prefix (or IPv4 mapped to IPv6). + */ + uint8_t LengthIPv6() const { return length; } - /** - * Returns true if the given address is part of the prefix. - * - * @param addr The address to test. - */ - bool Contains(const IPAddr& addr) const - { - IPAddr p(addr); - p.Mask(length); - return p == prefix; - } - /** - * Assignment operator. - */ - IPPrefix& operator=(const IPPrefix& other) - { - // No self-assignment check here because it's correct without it and - // makes the common case faster. - prefix = other.prefix; - length = other.length; - return *this; - } + /** + * Returns true if the given address is part of the prefix. + * + * @param addr The address to test. + */ + bool Contains(const IPAddr& addr) const { + IPAddr p(addr); + p.Mask(length); + return p == prefix; + } + /** + * Assignment operator. + */ + IPPrefix& operator=(const IPPrefix& other) { + // No self-assignment check here because it's correct without it and + // makes the common case faster. + prefix = other.prefix; + length = other.length; + return *this; + } - /** - * Returns a string representation of the prefix. IPv4 addresses - * will be returned in dotted representation, IPv6 addresses in - * compressed hex. - */ - std::string AsString() const; + /** + * Returns a string representation of the prefix. IPv4 addresses + * will be returned in dotted representation, IPv6 addresses in + * compressed hex. + */ + std::string AsString() const; - operator std::string() const { return AsString(); } + operator std::string() const { return AsString(); } - /** - * Returns a key that can be used to lookup the IP Prefix in a hash table. - */ - std::unique_ptr MakeHashKey() const; + /** + * Returns a key that can be used to lookup the IP Prefix in a hash table. + */ + std::unique_ptr MakeHashKey() const; - /** - * Converts the prefix into the type used internally by the - * inter-thread communication. - */ - void ConvertToThreadingValue(threading::Value::subnet_t* v) const - { - v->length = length; - prefix.ConvertToThreadingValue(&v->prefix); - } + /** + * Converts the prefix into the type used internally by the + * inter-thread communication. + */ + void ConvertToThreadingValue(threading::Value::subnet_t* v) const { + v->length = length; + prefix.ConvertToThreadingValue(&v->prefix); + } - /** - * Comparison operator for IP prefix. - */ - friend bool operator==(const IPPrefix& net1, const IPPrefix& net2) - { - return net1.Prefix() == net2.Prefix() && net1.Length() == net2.Length(); - } + /** + * Comparison operator for IP prefix. + */ + friend bool operator==(const IPPrefix& net1, const IPPrefix& net2) { + return net1.Prefix() == net2.Prefix() && net1.Length() == net2.Length(); + } - friend bool operator!=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 == net2); } + friend bool operator!=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 == net2); } - /** - * Comparison operator IP prefixes. This defines a well-defined order for - * IP prefix. However, the order does not necessarily corresponding to their - * numerical values. - */ - friend bool operator<(const IPPrefix& net1, const IPPrefix& net2) - { - if ( net1.Prefix() < net2.Prefix() ) - return true; + /** + * Comparison operator IP prefixes. This defines a well-defined order for + * IP prefix. However, the order does not necessarily corresponding to their + * numerical values. + */ + friend bool operator<(const IPPrefix& net1, const IPPrefix& net2) { + if ( net1.Prefix() < net2.Prefix() ) + return true; - else if ( net1.Prefix() == net2.Prefix() ) - return net1.Length() < net2.Length(); + else if ( net1.Prefix() == net2.Prefix() ) + return net1.Length() < net2.Length(); - else - return false; - } + else + return false; + } - friend bool operator<=(const IPPrefix& net1, const IPPrefix& net2) - { - return net1 < net2 || net1 == net2; - } + friend bool operator<=(const IPPrefix& net1, const IPPrefix& net2) { return net1 < net2 || net1 == net2; } - friend bool operator>=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 < net2); } + friend bool operator>=(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 < net2); } - friend bool operator>(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 <= net2); } + friend bool operator>(const IPPrefix& net1, const IPPrefix& net2) { return ! (net1 <= net2); } - /** - * Converts an IPv4 or IPv6 prefix string into a network address prefix structure. - * - * @param s the IPv4 or IPv6 prefix string to convert (ASCII, NUL-terminated). - * - * @param result buffer that the caller supplies to store the result. - * - * @return whether the conversion was successful. - */ - static bool ConvertString(const char* s, IPPrefix* result); + /** + * Converts an IPv4 or IPv6 prefix string into a network address prefix structure. + * + * @param s the IPv4 or IPv6 prefix string to convert (ASCII, NUL-terminated). + * + * @param result buffer that the caller supplies to store the result. + * + * @return whether the conversion was successful. + */ + static bool ConvertString(const char* s, IPPrefix* result); - /** - * @param s the IPv4 or IPv6 prefix string to convert (ASCII, NUL-terminated). - * - * @return whether the string is a valid IP address prefix - */ - static bool IsValid(const char* s) - { - IPPrefix tmp; - return ConvertString(s, &tmp); - } + /** + * @param s the IPv4 or IPv6 prefix string to convert (ASCII, NUL-terminated). + * + * @return whether the string is a valid IP address prefix + */ + static bool IsValid(const char* s) { + IPPrefix tmp; + return ConvertString(s, &tmp); + } private: - IPAddr prefix; // We store it as an address with the non-prefix bits masked out via Mask(). - uint8_t length = 0; // The bit length of the prefix relative to full IPv6 addr. - }; + IPAddr prefix; // We store it as an address with the non-prefix bits masked out via Mask(). + uint8_t length = 0; // The bit length of the prefix relative to full IPv6 addr. +}; - } // namespace zeek +} // namespace zeek diff --git a/src/IntSet.cc b/src/IntSet.cc index 554e673f37..e2cac41798 100644 --- a/src/IntSet.cc +++ b/src/IntSet.cc @@ -4,20 +4,18 @@ #include -namespace zeek::detail - { +namespace zeek::detail { -void IntSet::Expand(unsigned int i) - { - unsigned int newsize = i / 8 + 1; - unsigned char* newset = new unsigned char[newsize]; +void IntSet::Expand(unsigned int i) { + unsigned int newsize = i / 8 + 1; + unsigned char* newset = new unsigned char[newsize]; - memset(newset, 0, newsize); - memcpy(newset, set, size); + memset(newset, 0, newsize); + memcpy(newset, set, size); - delete[] set; - size = newsize; - set = newset; - } + delete[] set; + size = newsize; + set = newset; +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/IntSet.h b/src/IntSet.h index 65b0e22cf9..73c37a180d 100644 --- a/src/IntSet.h +++ b/src/IntSet.h @@ -8,65 +8,51 @@ #include -namespace zeek::detail - { +namespace zeek::detail { -class IntSet - { +class IntSet { public: - // n is a hint for the value of the largest integer. - explicit IntSet(unsigned int n = 1); - ~IntSet(); + // n is a hint for the value of the largest integer. + explicit IntSet(unsigned int n = 1); + ~IntSet(); - void Insert(unsigned int i); - void Remove(unsigned int i); - bool Contains(unsigned int i) const; + void Insert(unsigned int i); + void Remove(unsigned int i); + bool Contains(unsigned int i) const; - void Clear(); + void Clear(); private: - void Expand(unsigned int i); + void Expand(unsigned int i); - unsigned int size; - unsigned char* set; - }; + unsigned int size; + unsigned char* set; +}; -inline IntSet::IntSet(unsigned int n) - { - size = n / 8 + 1; - set = new unsigned char[size]; - memset(set, 0, size); - } +inline IntSet::IntSet(unsigned int n) { + size = n / 8 + 1; + set = new unsigned char[size]; + memset(set, 0, size); +} -inline IntSet::~IntSet() - { - delete[] set; - } +inline IntSet::~IntSet() { delete[] set; } -inline void IntSet::Insert(unsigned int i) - { - if ( i / 8 >= size ) - Expand(i); +inline void IntSet::Insert(unsigned int i) { + if ( i / 8 >= size ) + Expand(i); - set[i / 8] |= (1 << (i % 8)); - } + set[i / 8] |= (1 << (i % 8)); +} -inline void IntSet::Remove(unsigned int i) - { - if ( i / 8 >= size ) - Expand(i); - else - set[i / 8] &= ~(1 << (i % 8)); - } +inline void IntSet::Remove(unsigned int i) { + if ( i / 8 >= size ) + Expand(i); + else + set[i / 8] &= ~(1 << (i % 8)); +} -inline bool IntSet::Contains(unsigned int i) const - { - return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false; - } +inline bool IntSet::Contains(unsigned int i) const { return i / 8 < size ? set[i / 8] & (1 << (i % 8)) : false; } -inline void IntSet::Clear() - { - memset(set, 0, size); - } +inline void IntSet::Clear() { memset(set, 0, size); } - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/IntrusivePtr.h b/src/IntrusivePtr.h index 7d9578d29f..b384b66e9a 100644 --- a/src/IntrusivePtr.h +++ b/src/IntrusivePtr.h @@ -8,24 +8,19 @@ #include "Obj.h" -namespace zeek - { +namespace zeek { /** * A tag class for the #IntrusivePtr constructor which means: adopt * the reference from the caller. */ -struct AdoptRef - { - }; +struct AdoptRef {}; /** * A tag class for the #IntrusivePtr constructor which means: create a * new reference to the object. */ -struct NewRef - { - }; +struct NewRef {}; /** * This has to be forward declared and known here in order for us to be able @@ -55,131 +50,120 @@ class OpaqueVal; * should use a smart pointer whenever possible to reduce boilerplate code and * increase robustness of the code (in particular w.r.t. exceptions). */ -template class IntrusivePtr - { +template +class IntrusivePtr { public: - // -- member types + // -- member types - using pointer = T*; + using pointer = T*; - using const_pointer = const T*; + using const_pointer = const T*; - using element_type = T; + using element_type = T; - using reference = T&; + using reference = T&; - using const_reference = const T&; + using const_reference = const T&; - // -- constructors, destructors, and assignment operators + // -- constructors, destructors, and assignment operators - constexpr IntrusivePtr() noexcept = default; + constexpr IntrusivePtr() noexcept = default; - constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() - { - // nop - } + constexpr IntrusivePtr(std::nullptr_t) noexcept : IntrusivePtr() { + // nop + } - /** - * Constructs a new intrusive pointer for managing the lifetime of the object - * pointed to by @c raw_ptr. - * - * This overload adopts the existing reference from the caller. - * - * @param raw_ptr Pointer to the shared object. - */ - constexpr IntrusivePtr(AdoptRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) { } + /** + * Constructs a new intrusive pointer for managing the lifetime of the object + * pointed to by @c raw_ptr. + * + * This overload adopts the existing reference from the caller. + * + * @param raw_ptr Pointer to the shared object. + */ + constexpr IntrusivePtr(AdoptRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) {} - /** - * Constructs a new intrusive pointer for managing the lifetime of the object - * pointed to by @c raw_ptr. - * - * This overload adds a new reference. - * - * @param raw_ptr Pointer to the shared object. - */ - IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) - { - if ( ptr_ ) - Ref(ptr_); - } + /** + * Constructs a new intrusive pointer for managing the lifetime of the object + * pointed to by @c raw_ptr. + * + * This overload adds a new reference. + * + * @param raw_ptr Pointer to the shared object. + */ + IntrusivePtr(NewRef, pointer raw_ptr) noexcept : ptr_(raw_ptr) { + if ( ptr_ ) + Ref(ptr_); + } - IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) - { - // nop - } + IntrusivePtr(IntrusivePtr&& other) noexcept : ptr_(other.release()) { + // nop + } - IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) { } + IntrusivePtr(const IntrusivePtr& other) noexcept : IntrusivePtr(NewRef{}, other.get()) {} - template >> - IntrusivePtr(IntrusivePtr other) noexcept : ptr_(other.release()) - { - // nop - } + template>> + IntrusivePtr(IntrusivePtr other) noexcept : ptr_(other.release()) { + // nop + } - ~IntrusivePtr() - { - if ( ptr_ ) - { - // Specializing `OpaqueVal` as MSVC compiler does not detect it - // inheriting from `zeek::Obj` so we have to do that manually. - if constexpr ( std::is_same_v ) - Unref(reinterpret_cast(ptr_)); - else - Unref(ptr_); - } - } + ~IntrusivePtr() { + if ( ptr_ ) { + // Specializing `OpaqueVal` as MSVC compiler does not detect it + // inheriting from `zeek::Obj` so we have to do that manually. + if constexpr ( std::is_same_v ) + Unref(reinterpret_cast(ptr_)); + else + Unref(ptr_); + } + } - void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); } + void swap(IntrusivePtr& other) noexcept { std::swap(ptr_, other.ptr_); } - friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept - { - using std::swap; - swap(a.ptr_, b.ptr_); - } + friend void swap(IntrusivePtr& a, IntrusivePtr& b) noexcept { + using std::swap; + swap(a.ptr_, b.ptr_); + } - /** - * Detaches an object from the automated lifetime management and sets this - * intrusive pointer to @c nullptr. - * @returns the raw pointer without modifying the reference count. - */ - pointer release() noexcept { return std::exchange(ptr_, nullptr); } + /** + * Detaches an object from the automated lifetime management and sets this + * intrusive pointer to @c nullptr. + * @returns the raw pointer without modifying the reference count. + */ + pointer release() noexcept { return std::exchange(ptr_, nullptr); } - IntrusivePtr& operator=(const IntrusivePtr& other) noexcept - { - IntrusivePtr tmp{other}; - swap(tmp); - return *this; - } + IntrusivePtr& operator=(const IntrusivePtr& other) noexcept { + IntrusivePtr tmp{other}; + swap(tmp); + return *this; + } - IntrusivePtr& operator=(IntrusivePtr&& other) noexcept - { - swap(other); - return *this; - } + IntrusivePtr& operator=(IntrusivePtr&& other) noexcept { + swap(other); + return *this; + } - IntrusivePtr& operator=(std::nullptr_t) noexcept - { - if ( ptr_ ) - { - Unref(ptr_); - ptr_ = nullptr; - } - return *this; - } + IntrusivePtr& operator=(std::nullptr_t) noexcept { + if ( ptr_ ) { + Unref(ptr_); + ptr_ = nullptr; + } + return *this; + } - pointer get() const noexcept { return ptr_; } + pointer get() const noexcept { return ptr_; } - pointer operator->() const noexcept { return ptr_; } + pointer operator->() const noexcept { return ptr_; } - reference operator*() const noexcept { return *ptr_; } + reference operator*() const noexcept { return *ptr_; } - bool operator!() const noexcept { return ! ptr_; } + bool operator!() const noexcept { return ! ptr_; } - explicit operator bool() const noexcept { return ptr_ != nullptr; } + explicit operator bool() const noexcept { return ptr_ != nullptr; } private: - pointer ptr_ = nullptr; - }; + pointer ptr_ = nullptr; +}; /** * Convenience function for creating a reference counted object and wrapping it @@ -189,11 +173,11 @@ private: * @note This function assumes that any @c T starts with a reference count of 1. * @relates IntrusivePtr */ -template IntrusivePtr make_intrusive(Ts&&... args) - { - // Assumes that objects start with a reference count of 1! - return {AdoptRef{}, new T(std::forward(args)...)}; - } +template +IntrusivePtr make_intrusive(Ts&&... args) { + // Assumes that objects start with a reference count of 1! + return {AdoptRef{}, new T(std::forward(args)...)}; +} /** * Casts an @c IntrusivePtr object to another by way of static_cast on @@ -201,78 +185,78 @@ template IntrusivePtr make_intrusive(Ts&&... args) * @param p The pointer of type @c U to cast to another type, @c T. * @return The pointer, as cast to type @c T. */ -template IntrusivePtr cast_intrusive(IntrusivePtr p) noexcept - { - return {AdoptRef{}, static_cast(p.release())}; - } +template +IntrusivePtr cast_intrusive(IntrusivePtr p) noexcept { + return {AdoptRef{}, static_cast(p.release())}; +} // -- comparison to nullptr ---------------------------------------------------- /** * @relates IntrusivePtr */ -template bool operator==(const zeek::IntrusivePtr& x, std::nullptr_t) - { - return ! x; - } +template +bool operator==(const zeek::IntrusivePtr& x, std::nullptr_t) { + return ! x; +} /** * @relates IntrusivePtr */ -template bool operator==(std::nullptr_t, const zeek::IntrusivePtr& x) - { - return ! x; - } +template +bool operator==(std::nullptr_t, const zeek::IntrusivePtr& x) { + return ! x; +} /** * @relates IntrusivePtr */ -template bool operator!=(const zeek::IntrusivePtr& x, std::nullptr_t) - { - return static_cast(x); - } +template +bool operator!=(const zeek::IntrusivePtr& x, std::nullptr_t) { + return static_cast(x); +} /** * @relates IntrusivePtr */ -template bool operator!=(std::nullptr_t, const zeek::IntrusivePtr& x) - { - return static_cast(x); - } +template +bool operator!=(std::nullptr_t, const zeek::IntrusivePtr& x) { + return static_cast(x); +} // -- comparison to raw pointer ------------------------------------------------ /** * @relates IntrusivePtr */ -template bool operator==(const zeek::IntrusivePtr& x, const T* y) - { - return x.get() == y; - } +template +bool operator==(const zeek::IntrusivePtr& x, const T* y) { + return x.get() == y; +} /** * @relates IntrusivePtr */ -template bool operator==(const T* x, const zeek::IntrusivePtr& y) - { - return x == y.get(); - } +template +bool operator==(const T* x, const zeek::IntrusivePtr& y) { + return x == y.get(); +} /** * @relates IntrusivePtr */ -template bool operator!=(const zeek::IntrusivePtr& x, const T* y) - { - return x.get() != y; - } +template +bool operator!=(const zeek::IntrusivePtr& x, const T* y) { + return x.get() != y; +} /** * @relates IntrusivePtr */ -template bool operator!=(const T* x, const zeek::IntrusivePtr& y) - { - return x != y.get(); - } +template +bool operator!=(const T* x, const zeek::IntrusivePtr& y) { + return x != y.get(); +} // -- comparison to intrusive pointer ------------------------------------------ @@ -282,35 +266,27 @@ template bool operator!=(const T* x, const zeek::IntrusivePtr& y) /** * @relates IntrusivePtr */ -template -auto operator==(const zeek::IntrusivePtr& x, const zeek::IntrusivePtr& y) - -> decltype(x.get() == y.get()) - { - return x.get() == y.get(); - } +template +auto operator==(const zeek::IntrusivePtr& x, const zeek::IntrusivePtr& y) -> decltype(x.get() == y.get()) { + return x.get() == y.get(); +} /** * @relates IntrusivePtr */ -template -auto operator!=(const zeek::IntrusivePtr& x, const zeek::IntrusivePtr& y) - -> decltype(x.get() != y.get()) - { - return x.get() != y.get(); - } +template +auto operator!=(const zeek::IntrusivePtr& x, const zeek::IntrusivePtr& y) -> decltype(x.get() != y.get()) { + return x.get() != y.get(); +} - } // namespace zeek +} // namespace zeek // -- hashing ------------------------------------------------ -namespace std - { -template struct hash> - { - // Hash of intrusive pointer is the same as hash of the raw pointer it holds. - size_t operator()(const zeek::IntrusivePtr& v) const noexcept - { - return std::hash{}(v.get()); - } - }; - } +namespace std { +template +struct hash> { + // Hash of intrusive pointer is the same as hash of the raw pointer it holds. + size_t operator()(const zeek::IntrusivePtr& v) const noexcept { return std::hash{}(v.get()); } +}; +} // namespace std diff --git a/src/List.cc b/src/List.cc index 817dafcdf8..690d144c69 100644 --- a/src/List.cc +++ b/src/List.cc @@ -2,130 +2,125 @@ #include "zeek/3rdparty/doctest.h" -TEST_CASE("list construction") - { - zeek::List list; - CHECK(list.empty()); +TEST_CASE("list construction") { + zeek::List list; + CHECK(list.empty()); - zeek::List list2(10); - CHECK(list2.empty()); - CHECK(list2.max() == 10); - } + zeek::List list2(10); + CHECK(list2.empty()); + CHECK(list2.max() == 10); +} -TEST_CASE("list operation") - { - zeek::List list({1, 2, 3}); - CHECK(list.size() == 3); - CHECK(list.max() == 3); - CHECK(list[0] == 1); - CHECK(list[1] == 2); - CHECK(list[2] == 3); +TEST_CASE("list operation") { + zeek::List list({1, 2, 3}); + CHECK(list.size() == 3); + CHECK(list.max() == 3); + CHECK(list[0] == 1); + CHECK(list[1] == 2); + CHECK(list[2] == 3); - // push_back forces a resize of the list here, which grows the list - // by a growth factor. That makes the max elements equal to 6. - list.push_back(4); - CHECK(list.size() == 4); - CHECK(list.max() == 6); - CHECK(list[3] == 4); + // push_back forces a resize of the list here, which grows the list + // by a growth factor. That makes the max elements equal to 6. + list.push_back(4); + CHECK(list.size() == 4); + CHECK(list.max() == 6); + CHECK(list[3] == 4); - CHECK(list.front() == 1); - CHECK(list.back() == 4); + CHECK(list.front() == 1); + CHECK(list.back() == 4); - list.pop_front(); - CHECK(list.size() == 3); - CHECK(list.front() == 2); + list.pop_front(); + CHECK(list.size() == 3); + CHECK(list.front() == 2); - list.pop_back(); - CHECK(list.size() == 2); - CHECK(list.back() == 3); + list.pop_back(); + CHECK(list.size() == 2); + CHECK(list.back() == 3); - list.push_back(4); - CHECK(list.is_member(2)); - CHECK(list.member_pos(2) == 0); + list.push_back(4); + CHECK(list.is_member(2)); + CHECK(list.member_pos(2) == 0); - list.remove(2); - CHECK(list.size() == 2); - CHECK(list[0] == 3); - CHECK(list[1] == 4); + list.remove(2); + CHECK(list.size() == 2); + CHECK(list[0] == 3); + CHECK(list[1] == 4); - // Squash the list down to the existing elements. - list.resize(); - CHECK(list.size() == 2); - CHECK(list.max() == 2); + // Squash the list down to the existing elements. + list.resize(); + CHECK(list.size() == 2); + CHECK(list.max() == 2); - // Attempt replacing a known position. - int old = list.replace(0, 10); - CHECK(list.size() == 2); - CHECK(list.max() == 2); - CHECK(old == 3); - CHECK(list[0] == 10); - CHECK(list[1] == 4); + // Attempt replacing a known position. + int old = list.replace(0, 10); + CHECK(list.size() == 2); + CHECK(list.max() == 2); + CHECK(old == 3); + CHECK(list[0] == 10); + CHECK(list[1] == 4); - // Attempt replacing an element off the end of the list, which - // causes a resize. - old = list.replace(3, 5); - CHECK(list.size() == 4); - CHECK(list.max() == 4); - CHECK(old == 0); - CHECK(list[0] == 10); - CHECK(list[1] == 4); - CHECK(list[2] == 0); - CHECK(list[3] == 5); + // Attempt replacing an element off the end of the list, which + // causes a resize. + old = list.replace(3, 5); + CHECK(list.size() == 4); + CHECK(list.max() == 4); + CHECK(old == 0); + CHECK(list[0] == 10); + CHECK(list[1] == 4); + CHECK(list[2] == 0); + CHECK(list[3] == 5); - // Attempt replacing an element with a negative index, which returns the - // default value for the list type. - old = list.replace(-1, 50); - CHECK(list.size() == 4); - CHECK(list.max() == 4); - CHECK(old == 0); + // Attempt replacing an element with a negative index, which returns the + // default value for the list type. + old = list.replace(-1, 50); + CHECK(list.size() == 4); + CHECK(list.max() == 4); + CHECK(old == 0); - list.clear(); - CHECK(list.size() == 0); - CHECK(list.max() == 0); - } + list.clear(); + CHECK(list.size() == 0); + CHECK(list.max() == 0); +} -TEST_CASE("list iteration") - { - zeek::List list({1, 2, 3, 4}); +TEST_CASE("list iteration") { + zeek::List list({1, 2, 3, 4}); - int index = 1; - for ( int v : list ) - CHECK(v == index++); + int index = 1; + for ( int v : list ) + CHECK(v == index++); - index = 1; - for ( auto it = list.begin(); it != list.end(); index++, ++it ) - CHECK(*it == index); - } + index = 1; + for ( auto it = list.begin(); it != list.end(); index++, ++it ) + CHECK(*it == index); +} -TEST_CASE("plists") - { - zeek::PList list; - list.push_back(new int{1}); - list.push_back(new int{2}); - list.push_back(new int{3}); +TEST_CASE("plists") { + zeek::PList list; + list.push_back(new int{1}); + list.push_back(new int{2}); + list.push_back(new int{3}); - CHECK(*list[0] == 1); + CHECK(*list[0] == 1); - int* new_val = new int(5); - auto old = list.replace(-1, new_val); - delete new_val; - CHECK(old == nullptr); + int* new_val = new int(5); + auto old = list.replace(-1, new_val); + delete new_val; + CHECK(old == nullptr); - for ( auto v : list ) - delete v; - list.clear(); - } + for ( auto v : list ) + delete v; + list.clear(); +} -TEST_CASE("unordered list operation") - { - zeek::List list({1, 2, 3, 4}); - CHECK(list.size() == 4); +TEST_CASE("unordered list operation") { + zeek::List list({1, 2, 3, 4}); + CHECK(list.size() == 4); - // An unordered list doesn't maintain the ordering of the elements when - // one is removed. It just swaps the last element into the hole. - list.remove(2); - CHECK(list.size() == 3); - CHECK(list[0] == 1); - CHECK(list[1] == 4); - CHECK(list[2] == 3); - } + // An unordered list doesn't maintain the ordering of the elements when + // one is removed. It just swaps the last element into the hole. + list.remove(2); + CHECK(list.size() == 3); + CHECK(list[0] == 1); + CHECK(list[1] == 4); + CHECK(list[2] == 3); +} diff --git a/src/List.h b/src/List.h index 7a3d379148..4db1ded1f3 100644 --- a/src/List.h +++ b/src/List.h @@ -27,313 +27,293 @@ #include "zeek/util.h" -namespace zeek - { +namespace zeek { -enum class ListOrder : int - { - ORDERED, - UNORDERED - }; +enum class ListOrder : int { ORDERED, UNORDERED }; -template class List - { +template +class List { public: - constexpr static int DEFAULT_LIST_SIZE = 10; - constexpr static int LIST_GROWTH_FACTOR = 2; + constexpr static int DEFAULT_LIST_SIZE = 10; + constexpr static int LIST_GROWTH_FACTOR = 2; - ~List() { free(entries); } - explicit List(int size = 0) - { - num_entries = 0; + ~List() { free(entries); } + explicit List(int size = 0) { + num_entries = 0; - if ( size <= 0 ) - { - max_entries = 0; - entries = nullptr; - return; - } + if ( size <= 0 ) { + max_entries = 0; + entries = nullptr; + return; + } - max_entries = size; + max_entries = size; - entries = (T*)util::safe_malloc(max_entries * sizeof(T)); - } + entries = (T*)util::safe_malloc(max_entries * sizeof(T)); + } - List(const List& b) - { - max_entries = b.max_entries; - num_entries = b.num_entries; + List(const List& b) { + max_entries = b.max_entries; + num_entries = b.num_entries; - if ( max_entries ) - entries = (T*)util::safe_malloc(max_entries * sizeof(T)); - else - entries = nullptr; + if ( max_entries ) + entries = (T*)util::safe_malloc(max_entries * sizeof(T)); + else + entries = nullptr; - for ( int i = 0; i < num_entries; ++i ) - entries[i] = b.entries[i]; - } + for ( int i = 0; i < num_entries; ++i ) + entries[i] = b.entries[i]; + } - List(List&& b) - { - entries = b.entries; - num_entries = b.num_entries; - max_entries = b.max_entries; + List(List&& b) { + entries = b.entries; + num_entries = b.num_entries; + max_entries = b.max_entries; - b.entries = nullptr; - b.num_entries = b.max_entries = 0; - } + b.entries = nullptr; + b.num_entries = b.max_entries = 0; + } - List(const T* arr, int n) - { - num_entries = max_entries = n; - entries = (T*)util::safe_malloc(max_entries * sizeof(T)); - memcpy(entries, arr, n * sizeof(T)); - } + List(const T* arr, int n) { + num_entries = max_entries = n; + entries = (T*)util::safe_malloc(max_entries * sizeof(T)); + memcpy(entries, arr, n * sizeof(T)); + } - List(std::initializer_list il) : List(il.begin(), il.size()) { } + List(std::initializer_list il) : List(il.begin(), il.size()) {} - List& operator=(const List& b) - { - if ( this == &b ) - return *this; + List& operator=(const List& b) { + if ( this == &b ) + return *this; - free(entries); + free(entries); - max_entries = b.max_entries; - num_entries = b.num_entries; + max_entries = b.max_entries; + num_entries = b.num_entries; - if ( max_entries ) - entries = (T*)util::safe_malloc(max_entries * sizeof(T)); - else - entries = nullptr; + if ( max_entries ) + entries = (T*)util::safe_malloc(max_entries * sizeof(T)); + else + entries = nullptr; - for ( int i = 0; i < num_entries; ++i ) - entries[i] = b.entries[i]; + for ( int i = 0; i < num_entries; ++i ) + entries[i] = b.entries[i]; - return *this; - } + return *this; + } - List& operator=(List&& b) - { - if ( this == &b ) - return *this; + List& operator=(List&& b) { + if ( this == &b ) + return *this; - free(entries); - entries = b.entries; - num_entries = b.num_entries; - max_entries = b.max_entries; + free(entries); + entries = b.entries; + num_entries = b.num_entries; + max_entries = b.max_entries; - b.entries = nullptr; - b.num_entries = b.max_entries = 0; - return *this; - } + b.entries = nullptr; + b.num_entries = b.max_entries = 0; + return *this; + } - // Return nth ent of list (do not remove). - T& operator[](int i) const { return entries[i]; } + // Return nth ent of list (do not remove). + T& operator[](int i) const { return entries[i]; } - void clear() // remove all entries - { - free(entries); - entries = nullptr; - num_entries = max_entries = 0; - } + void clear() // remove all entries + { + free(entries); + entries = nullptr; + num_entries = max_entries = 0; + } - bool empty() const noexcept { return num_entries == 0; } - size_t size() const noexcept { return num_entries; } + bool empty() const noexcept { return num_entries == 0; } + size_t size() const noexcept { return num_entries; } - int length() const { return num_entries; } - int max() const { return max_entries; } - int resize(int new_size = 0) // 0 => size to fit current number of entries - { - if ( new_size < num_entries ) - new_size = num_entries; // do not lose any entries + int length() const { return num_entries; } + int max() const { return max_entries; } + int resize(int new_size = 0) // 0 => size to fit current number of entries + { + if ( new_size < num_entries ) + new_size = num_entries; // do not lose any entries - if ( new_size != max_entries ) - { - entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size); - if ( entries ) - max_entries = new_size; - else - max_entries = 0; - } + if ( new_size != max_entries ) { + entries = (T*)util::safe_realloc((void*)entries, sizeof(T) * new_size); + if ( entries ) + max_entries = new_size; + else + max_entries = 0; + } - return max_entries; - } + return max_entries; + } - void push_front(const T& a) - { - if ( num_entries == max_entries ) - resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); + void push_front(const T& a) { + if ( num_entries == max_entries ) + resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); - for ( int i = num_entries; i > 0; --i ) - entries[i] = entries[i - 1]; // move all pointers up one + for ( int i = num_entries; i > 0; --i ) + entries[i] = entries[i - 1]; // move all pointers up one - ++num_entries; - entries[0] = a; - } + ++num_entries; + entries[0] = a; + } - void push_back(const T& a) - { - if ( num_entries == max_entries ) - resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); + void push_back(const T& a) { + if ( num_entries == max_entries ) + resize(max_entries ? max_entries * LIST_GROWTH_FACTOR : DEFAULT_LIST_SIZE); - entries[num_entries++] = a; - } + entries[num_entries++] = a; + } - void pop_front() { remove_nth(0); } - void pop_back() { remove_nth(num_entries - 1); } + void pop_front() { remove_nth(0); } + void pop_back() { remove_nth(num_entries - 1); } - T& front() { return entries[0]; } - T& back() { return entries[num_entries - 1]; } + T& front() { return entries[0]; } + T& back() { return entries[num_entries - 1]; } - // The append method is maintained for historical/compatibility reasons. - // (It's commonly used in the event generation API) - void append(const T& a) // add to end of list - { - push_back(a); - } + // The append method is maintained for historical/compatibility reasons. + // (It's commonly used in the event generation API) + void append(const T& a) // add to end of list + { + push_back(a); + } - bool remove(const T& a) // delete entry from list - { - int pos = member_pos(a); - if ( pos != -1 ) - { - remove_nth(pos); - return true; - } + bool remove(const T& a) // delete entry from list + { + int pos = member_pos(a); + if ( pos != -1 ) { + remove_nth(pos); + return true; + } - return false; - } + return false; + } - T remove_nth(int n) // delete nth entry from list - { - assert(n >= 0 && n < num_entries); + T remove_nth(int n) // delete nth entry from list + { + assert(n >= 0 && n < num_entries); - T old_ent = entries[n]; + T old_ent = entries[n]; - // For data where we don't care about ordering, we don't care about keeping - // the list in the same order when removing an element. Just swap the last - // element with the element being removed. - if constexpr ( Order == ListOrder::ORDERED ) - { - --num_entries; + // For data where we don't care about ordering, we don't care about keeping + // the list in the same order when removing an element. Just swap the last + // element with the element being removed. + if constexpr ( Order == ListOrder::ORDERED ) { + --num_entries; - for ( ; n < num_entries; ++n ) - entries[n] = entries[n + 1]; - } - else - { - entries[n] = entries[num_entries - 1]; - --num_entries; - } + for ( ; n < num_entries; ++n ) + entries[n] = entries[n + 1]; + } + else { + entries[n] = entries[num_entries - 1]; + --num_entries; + } - return old_ent; - } + return old_ent; + } - // Return 0 if ent is not in the list, ent otherwise. - bool is_member(const T& a) const - { - int pos = member_pos(a); - return pos != -1; - } + // Return 0 if ent is not in the list, ent otherwise. + bool is_member(const T& a) const { + int pos = member_pos(a); + return pos != -1; + } - // Returns -1 if ent is not in the list, otherwise its position. - int member_pos(const T& e) const - { - int i; - for ( i = 0; i < length() && e != entries[i]; ++i ) - ; + // Returns -1 if ent is not in the list, otherwise its position. + int member_pos(const T& e) const { + int i; + for ( i = 0; i < length() && e != entries[i]; ++i ) + ; - return (i == length()) ? -1 : i; - } + return (i == length()) ? -1 : i; + } - T replace(int ent_index, const T& new_ent) // replace entry #i with a new value - { - if ( ent_index < 0 ) - return T{}; + T replace(int ent_index, const T& new_ent) // replace entry #i with a new value + { + if ( ent_index < 0 ) + return T{}; - T old_ent{}; + T old_ent{}; - if ( ent_index > num_entries - 1 ) - { // replacement beyond the end of the list - resize(ent_index + 1); + if ( ent_index > num_entries - 1 ) { // replacement beyond the end of the list + resize(ent_index + 1); - for ( int i = num_entries; i < max_entries; ++i ) - entries[i] = T{}; - num_entries = max_entries; - } - else - old_ent = entries[ent_index]; + for ( int i = num_entries; i < max_entries; ++i ) + entries[i] = T{}; + num_entries = max_entries; + } + else + old_ent = entries[ent_index]; - entries[ent_index] = new_ent; + entries[ent_index] = new_ent; - return old_ent; - } + return old_ent; + } - // Type traits needed for some of the std algorithms to work - using value_type = T; - using pointer = T*; - using const_pointer = const T*; + // Type traits needed for some of the std algorithms to work + using value_type = T; + using pointer = T*; + using const_pointer = const T*; - // Iterator support - using iterator = pointer; - using const_iterator = const_pointer; - using reverse_iterator = std::reverse_iterator; - using const_reverse_iterator = std::reverse_iterator; + // Iterator support + using iterator = pointer; + using const_iterator = const_pointer; + using reverse_iterator = std::reverse_iterator; + using const_reverse_iterator = std::reverse_iterator; - iterator begin() { return entries; } - iterator end() { return entries + num_entries; } - const_iterator begin() const { return entries; } - const_iterator end() const { return entries + num_entries; } - const_iterator cbegin() const { return entries; } - const_iterator cend() const { return entries + num_entries; } + iterator begin() { return entries; } + iterator end() { return entries + num_entries; } + const_iterator begin() const { return entries; } + const_iterator end() const { return entries + num_entries; } + const_iterator cbegin() const { return entries; } + const_iterator cend() const { return entries + num_entries; } - reverse_iterator rbegin() { return reverse_iterator{end()}; } - reverse_iterator rend() { return reverse_iterator{begin()}; } - const_reverse_iterator rbegin() const { return const_reverse_iterator{end()}; } - const_reverse_iterator rend() const { return const_reverse_iterator{begin()}; } - const_reverse_iterator crbegin() const { return rbegin(); } - const_reverse_iterator crend() const { return rend(); } + reverse_iterator rbegin() { return reverse_iterator{end()}; } + reverse_iterator rend() { return reverse_iterator{begin()}; } + const_reverse_iterator rbegin() const { return const_reverse_iterator{end()}; } + const_reverse_iterator rend() const { return const_reverse_iterator{begin()}; } + const_reverse_iterator crbegin() const { return rbegin(); } + const_reverse_iterator crend() const { return rend(); } protected: - // This could essentially be an std::vector if we wanted. Some - // reasons to maybe not refactor to use std::vector ? - // - // - Harder to use a custom growth factor. Also, the growth - // factor would be implementation-specific, taking some control over - // performance out of our hands. - // - // - It won't ever take advantage of realloc's occasional ability to - // grow in-place. - // - // - Combine above point this with lack of control of growth - // factor means the common choice of 2x growth factor causes - // a growth pattern that crawls forward in memory with no possible - // re-use of previous chunks (the new capacity is always larger than - // all previously allocated chunks combined). This point and - // whether 2x is empirically an issue still seems debated (at least - // GCC seems to stand by 2x as empirically better). - // - // - Sketchy shrinking behavior: standard says that requests to - // shrink are non-binding (it's expected implementations heed, but - // still not great to have no guarantee). Also, it would not take - // advantage of realloc's ability to contract in-place, it would - // allocate-and-copy. + // This could essentially be an std::vector if we wanted. Some + // reasons to maybe not refactor to use std::vector ? + // + // - Harder to use a custom growth factor. Also, the growth + // factor would be implementation-specific, taking some control over + // performance out of our hands. + // + // - It won't ever take advantage of realloc's occasional ability to + // grow in-place. + // + // - Combine above point this with lack of control of growth + // factor means the common choice of 2x growth factor causes + // a growth pattern that crawls forward in memory with no possible + // re-use of previous chunks (the new capacity is always larger than + // all previously allocated chunks combined). This point and + // whether 2x is empirically an issue still seems debated (at least + // GCC seems to stand by 2x as empirically better). + // + // - Sketchy shrinking behavior: standard says that requests to + // shrink are non-binding (it's expected implementations heed, but + // still not great to have no guarantee). Also, it would not take + // advantage of realloc's ability to contract in-place, it would + // allocate-and-copy. - T* entries; - int max_entries; - int num_entries; - }; + T* entries; + int max_entries; + int num_entries; +}; // Specialization of the List class to store pointers of a type. -template using PList = List; +template +using PList = List; // Popular type of list: list of strings. using name_list = PList; - } // namespace zeek +} // namespace zeek // Macro to visit each list element in turn. -#define loop_over_list(list, iterator) \ - int iterator; \ - for ( iterator = 0; iterator < (list).length(); ++iterator ) +#define loop_over_list(list, iterator) \ + int iterator; \ + for ( iterator = 0; iterator < (list).length(); ++iterator ) diff --git a/src/NFA.cc b/src/NFA.cc index db629c5c58..20e1e75abc 100644 --- a/src/NFA.cc +++ b/src/NFA.cc @@ -10,360 +10,314 @@ #include "zeek/EquivClass.h" #include "zeek/IntSet.h" -namespace zeek::detail - { +namespace zeek::detail { static int nfa_state_id = 0; -NFA_State::NFA_State(int arg_sym, EquivClass* ec) - { - sym = arg_sym; - ccl = nullptr; - accept = NO_ACCEPT; - first_trans_is_back_ref = false; - mark = nullptr; - epsclosure = nullptr; - id = ++nfa_state_id; - - // Fix up equivalence classes based on this transition. Note that any - // character which has its own transition gets its own equivalence - // class. Thus only characters which are only in character classes - // have a chance at being in the same equivalence class. E.g. "a|b" - // puts 'a' and 'b' into two different equivalence classes. "[ab]" - // puts them in the same equivalence class (barring other differences - // elsewhere in the input). - - if ( ec && sym != SYM_EPSILON /* no associated symbol */ ) - ec->UniqueChar(sym); - } - -NFA_State::NFA_State(CCL* arg_ccl) - { - sym = SYM_CCL; - ccl = arg_ccl; - accept = NO_ACCEPT; - first_trans_is_back_ref = false; - mark = nullptr; - id = ++nfa_state_id; - epsclosure = nullptr; - } - -NFA_State::~NFA_State() - { - for ( int i = 0; i < xtions.length(); ++i ) - if ( i > 0 || ! first_trans_is_back_ref ) - Unref(xtions[i]); - - delete epsclosure; - } - -void NFA_State::AddXtionsTo(NFA_state_list* ns) - { - for ( int i = 0; i < xtions.length(); ++i ) - ns->push_back(xtions[i]); - } - -NFA_State* NFA_State::DeepCopy() - { - if ( mark ) - { - Ref(mark); - return mark; - } - - NFA_State* copy = ccl ? new NFA_State(ccl) : new NFA_State(sym, nullptr); - SetMark(copy); - - for ( int i = 0; i < xtions.length(); ++i ) - copy->AddXtion(xtions[i]->DeepCopy()); - - return copy; - } - -void NFA_State::ClearMarks() - { - if ( mark ) - { - SetMark(nullptr); - for ( int i = 0; i < xtions.length(); ++i ) - xtions[i]->ClearMarks(); - } - } - -NFA_state_list* NFA_State::EpsilonClosure() - { - if ( epsclosure ) - return epsclosure; - - epsclosure = new NFA_state_list; - - NFA_state_list states; - states.push_back(this); - SetMark(this); - - int i; - for ( i = 0; i < states.length(); ++i ) - { - NFA_State* ns = states[i]; - if ( ns->TransSym() == SYM_EPSILON ) - { - NFA_state_list* x = ns->Transitions(); - for ( int j = 0; j < x->length(); ++j ) - { - NFA_State* nxt = (*x)[j]; - if ( ! nxt->Mark() ) - { - states.push_back(nxt); - nxt->SetMark(nxt); - } - } - - if ( ns->Accept() != NO_ACCEPT ) - epsclosure->push_back(ns); - } - - else - // Non-epsilon transition - keep it. - epsclosure->push_back(ns); - } - - // Clear out markers. - for ( i = 0; i < states.length(); ++i ) - states[i]->SetMark(nullptr); - - // Make it fit. - epsclosure->resize(0); - - return epsclosure; - } - -void NFA_State::Describe(ODesc* d) const - { - d->Add("NFA state"); - } - -void NFA_State::Dump(FILE* f) - { - if ( mark ) - return; - - fprintf(f, "NFA state %d, sym = %d, accept = %d:\n", id, sym, accept); - - for ( int i = 0; i < xtions.length(); ++i ) - fprintf(f, "\ttransition to %d\n", xtions[i]->ID()); - - SetMark(this); - for ( int i = 0; i < xtions.length(); ++i ) - xtions[i]->Dump(f); - } - -NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) - { - first_state = first; - final_state = final ? final : first; - eol = bol = 0; - } - -NFA_Machine::~NFA_Machine() - { - Unref(first_state); - } - -void NFA_Machine::InsertEpsilon() - { - NFA_State* eps = new EpsilonState(); - eps->AddXtion(first_state); - first_state = eps; - } - -void NFA_Machine::AppendEpsilon() - { - AppendState(new EpsilonState()); - } - -void NFA_Machine::AddAccept(int accept_val) - { - // Hang the accepting number off an epsilon state. If it is associated - // with a state that has a non-epsilon out-transition, then the state - // will accept BEFORE it makes that transition, i.e., one character - // too soon. - - if ( final_state->TransSym() != SYM_EPSILON ) - AppendState(new EpsilonState()); - - final_state->SetAccept(accept_val); - } - -void NFA_Machine::LinkCopies(int n) - { - if ( n <= 0 ) - return; - - // Make all the copies before doing any appending, otherwise - // subsequent DuplicateMachine()'s will include the extra - // copies! - NFA_Machine** copies = new NFA_Machine*[n]; - - int i; - for ( i = 0; i < n; ++i ) - copies[i] = DuplicateMachine(); - - for ( i = 0; i < n; ++i ) - AppendMachine(copies[i]); - - delete[] copies; - } - -NFA_Machine* NFA_Machine::DuplicateMachine() - { - NFA_State* new_first_state = first_state->DeepCopy(); - NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark()); - first_state->ClearMarks(); - - return new_m; - } - -void NFA_Machine::AppendState(NFA_State* s) - { - final_state->AddXtion(s); - final_state = s; - } - -void NFA_Machine::AppendMachine(NFA_Machine* m) - { - AppendEpsilon(); - final_state->AddXtion(m->FirstState()); - final_state = m->FinalState(); - - Ref(m->FirstState()); // so states stay around after the following - Unref(m); - } - -void NFA_Machine::MakeOptional() - { - InsertEpsilon(); - AppendEpsilon(); - first_state->AddXtion(final_state); - Ref(final_state); - } - -void NFA_Machine::MakePositiveClosure() - { - AppendEpsilon(); - final_state->AddXtion(first_state); - - // Don't Ref the state the final epsilon points to, otherwise we'll - // have reference cycles that lead to leaks. - final_state->SetFirstTransIsBackRef(); - } - -void NFA_Machine::MakeRepl(int lower, int upper) - { - NFA_Machine* dup = nullptr; - if ( upper > lower || upper == NO_UPPER_BOUND ) - dup = DuplicateMachine(); - - LinkCopies(lower - 1); - - if ( upper == NO_UPPER_BOUND ) - { - dup->MakeClosure(); - AppendMachine(dup); - return; - } - - while ( upper > lower ) - { - NFA_Machine* dup2; - if ( --upper == lower ) - // Don't need "dup" for any further copies - dup2 = dup; - else - dup2 = dup->DuplicateMachine(); - - dup2->MakeOptional(); - AppendMachine(dup2); - } - } - -void NFA_Machine::Describe(ODesc* d) const - { - d->Add("NFA machine"); - } - -void NFA_Machine::Dump(FILE* f) - { - first_state->Dump(f); - first_state->ClearMarks(); - } - -NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) - { - if ( ! m1 ) - return m2; - if ( ! m2 ) - return m1; - - NFA_State* first = new EpsilonState(); - NFA_State* last = new EpsilonState(); - - first->AddXtion(m1->FirstState()); - first->AddXtion(m2->FirstState()); - - m1->AppendState(last); - m2->AppendState(last); - Ref(last); - - // Keep these around. - Ref(m1->FirstState()); - Ref(m2->FirstState()); - - Unref(m1); - Unref(m2); - - return new NFA_Machine(first, last); - } - -NFA_state_list* epsilon_closure(NFA_state_list* states) - { - // We just keep one of this as it may get quite large. - static IntSet closuremap; - closuremap.Clear(); - - NFA_state_list* closure = new NFA_state_list; - - for ( int i = 0; i < states->length(); ++i ) - { - NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure(); - - for ( int j = 0; j < stateclosure->length(); ++j ) - { - NFA_State* ns = (*stateclosure)[j]; - if ( ! closuremap.Contains(ns->ID()) ) - { - closuremap.Insert(ns->ID()); - closure->push_back(ns); - } - } - } - - // Sort all of the closures in the list by ID - std::sort(closure->begin(), closure->end(), NFA_state_cmp_neg); - - // Make it fit. - closure->resize(0); - - delete states; - - return closure; - } - -bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2) - { - if ( v1->ID() < v2->ID() ) - return true; - else - return false; - } - - } // namespace zeek::detail +NFA_State::NFA_State(int arg_sym, EquivClass* ec) { + sym = arg_sym; + ccl = nullptr; + accept = NO_ACCEPT; + first_trans_is_back_ref = false; + mark = nullptr; + epsclosure = nullptr; + id = ++nfa_state_id; + + // Fix up equivalence classes based on this transition. Note that any + // character which has its own transition gets its own equivalence + // class. Thus only characters which are only in character classes + // have a chance at being in the same equivalence class. E.g. "a|b" + // puts 'a' and 'b' into two different equivalence classes. "[ab]" + // puts them in the same equivalence class (barring other differences + // elsewhere in the input). + + if ( ec && sym != SYM_EPSILON /* no associated symbol */ ) + ec->UniqueChar(sym); +} + +NFA_State::NFA_State(CCL* arg_ccl) { + sym = SYM_CCL; + ccl = arg_ccl; + accept = NO_ACCEPT; + first_trans_is_back_ref = false; + mark = nullptr; + id = ++nfa_state_id; + epsclosure = nullptr; +} + +NFA_State::~NFA_State() { + for ( int i = 0; i < xtions.length(); ++i ) + if ( i > 0 || ! first_trans_is_back_ref ) + Unref(xtions[i]); + + delete epsclosure; +} + +void NFA_State::AddXtionsTo(NFA_state_list* ns) { + for ( int i = 0; i < xtions.length(); ++i ) + ns->push_back(xtions[i]); +} + +NFA_State* NFA_State::DeepCopy() { + if ( mark ) { + Ref(mark); + return mark; + } + + NFA_State* copy = ccl ? new NFA_State(ccl) : new NFA_State(sym, nullptr); + SetMark(copy); + + for ( int i = 0; i < xtions.length(); ++i ) + copy->AddXtion(xtions[i]->DeepCopy()); + + return copy; +} + +void NFA_State::ClearMarks() { + if ( mark ) { + SetMark(nullptr); + for ( int i = 0; i < xtions.length(); ++i ) + xtions[i]->ClearMarks(); + } +} + +NFA_state_list* NFA_State::EpsilonClosure() { + if ( epsclosure ) + return epsclosure; + + epsclosure = new NFA_state_list; + + NFA_state_list states; + states.push_back(this); + SetMark(this); + + int i; + for ( i = 0; i < states.length(); ++i ) { + NFA_State* ns = states[i]; + if ( ns->TransSym() == SYM_EPSILON ) { + NFA_state_list* x = ns->Transitions(); + for ( int j = 0; j < x->length(); ++j ) { + NFA_State* nxt = (*x)[j]; + if ( ! nxt->Mark() ) { + states.push_back(nxt); + nxt->SetMark(nxt); + } + } + + if ( ns->Accept() != NO_ACCEPT ) + epsclosure->push_back(ns); + } + + else + // Non-epsilon transition - keep it. + epsclosure->push_back(ns); + } + + // Clear out markers. + for ( i = 0; i < states.length(); ++i ) + states[i]->SetMark(nullptr); + + // Make it fit. + epsclosure->resize(0); + + return epsclosure; +} + +void NFA_State::Describe(ODesc* d) const { d->Add("NFA state"); } + +void NFA_State::Dump(FILE* f) { + if ( mark ) + return; + + fprintf(f, "NFA state %d, sym = %d, accept = %d:\n", id, sym, accept); + + for ( int i = 0; i < xtions.length(); ++i ) + fprintf(f, "\ttransition to %d\n", xtions[i]->ID()); + + SetMark(this); + for ( int i = 0; i < xtions.length(); ++i ) + xtions[i]->Dump(f); +} + +NFA_Machine::NFA_Machine(NFA_State* first, NFA_State* final) { + first_state = first; + final_state = final ? final : first; + eol = bol = 0; +} + +NFA_Machine::~NFA_Machine() { Unref(first_state); } + +void NFA_Machine::InsertEpsilon() { + NFA_State* eps = new EpsilonState(); + eps->AddXtion(first_state); + first_state = eps; +} + +void NFA_Machine::AppendEpsilon() { AppendState(new EpsilonState()); } + +void NFA_Machine::AddAccept(int accept_val) { + // Hang the accepting number off an epsilon state. If it is associated + // with a state that has a non-epsilon out-transition, then the state + // will accept BEFORE it makes that transition, i.e., one character + // too soon. + + if ( final_state->TransSym() != SYM_EPSILON ) + AppendState(new EpsilonState()); + + final_state->SetAccept(accept_val); +} + +void NFA_Machine::LinkCopies(int n) { + if ( n <= 0 ) + return; + + // Make all the copies before doing any appending, otherwise + // subsequent DuplicateMachine()'s will include the extra + // copies! + NFA_Machine** copies = new NFA_Machine*[n]; + + int i; + for ( i = 0; i < n; ++i ) + copies[i] = DuplicateMachine(); + + for ( i = 0; i < n; ++i ) + AppendMachine(copies[i]); + + delete[] copies; +} + +NFA_Machine* NFA_Machine::DuplicateMachine() { + NFA_State* new_first_state = first_state->DeepCopy(); + NFA_Machine* new_m = new NFA_Machine(new_first_state, final_state->Mark()); + first_state->ClearMarks(); + + return new_m; +} + +void NFA_Machine::AppendState(NFA_State* s) { + final_state->AddXtion(s); + final_state = s; +} + +void NFA_Machine::AppendMachine(NFA_Machine* m) { + AppendEpsilon(); + final_state->AddXtion(m->FirstState()); + final_state = m->FinalState(); + + Ref(m->FirstState()); // so states stay around after the following + Unref(m); +} + +void NFA_Machine::MakeOptional() { + InsertEpsilon(); + AppendEpsilon(); + first_state->AddXtion(final_state); + Ref(final_state); +} + +void NFA_Machine::MakePositiveClosure() { + AppendEpsilon(); + final_state->AddXtion(first_state); + + // Don't Ref the state the final epsilon points to, otherwise we'll + // have reference cycles that lead to leaks. + final_state->SetFirstTransIsBackRef(); +} + +void NFA_Machine::MakeRepl(int lower, int upper) { + NFA_Machine* dup = nullptr; + if ( upper > lower || upper == NO_UPPER_BOUND ) + dup = DuplicateMachine(); + + LinkCopies(lower - 1); + + if ( upper == NO_UPPER_BOUND ) { + dup->MakeClosure(); + AppendMachine(dup); + return; + } + + while ( upper > lower ) { + NFA_Machine* dup2; + if ( --upper == lower ) + // Don't need "dup" for any further copies + dup2 = dup; + else + dup2 = dup->DuplicateMachine(); + + dup2->MakeOptional(); + AppendMachine(dup2); + } +} + +void NFA_Machine::Describe(ODesc* d) const { d->Add("NFA machine"); } + +void NFA_Machine::Dump(FILE* f) { + first_state->Dump(f); + first_state->ClearMarks(); +} + +NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2) { + if ( ! m1 ) + return m2; + if ( ! m2 ) + return m1; + + NFA_State* first = new EpsilonState(); + NFA_State* last = new EpsilonState(); + + first->AddXtion(m1->FirstState()); + first->AddXtion(m2->FirstState()); + + m1->AppendState(last); + m2->AppendState(last); + Ref(last); + + // Keep these around. + Ref(m1->FirstState()); + Ref(m2->FirstState()); + + Unref(m1); + Unref(m2); + + return new NFA_Machine(first, last); +} + +NFA_state_list* epsilon_closure(NFA_state_list* states) { + // We just keep one of this as it may get quite large. + static IntSet closuremap; + closuremap.Clear(); + + NFA_state_list* closure = new NFA_state_list; + + for ( int i = 0; i < states->length(); ++i ) { + NFA_state_list* stateclosure = (*states)[i]->EpsilonClosure(); + + for ( int j = 0; j < stateclosure->length(); ++j ) { + NFA_State* ns = (*stateclosure)[j]; + if ( ! closuremap.Contains(ns->ID()) ) { + closuremap.Insert(ns->ID()); + closure->push_back(ns); + } + } + } + + // Sort all of the closures in the list by ID + std::sort(closure->begin(), closure->end(), NFA_state_cmp_neg); + + // Make it fit. + closure->resize(0); + + delete states; + + return closure; +} + +bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2) { + if ( v1->ID() < v2->ID() ) + return true; + else + return false; +} + +} // namespace zeek::detail diff --git a/src/NFA.h b/src/NFA.h index 6304662639..8364dee6d6 100644 --- a/src/NFA.h +++ b/src/NFA.h @@ -16,13 +16,11 @@ #define SYM_EPSILON 259 #define SYM_CCL 260 -namespace zeek - { +namespace zeek { class Func; -namespace detail - { +namespace detail { class CCL; class EquivClass; @@ -30,104 +28,100 @@ class EquivClass; class NFA_State; using NFA_state_list = PList; -class NFA_State : public Obj - { +class NFA_State : public Obj { public: - NFA_State(int sym, EquivClass* ec); - explicit NFA_State(CCL* ccl); - ~NFA_State() override; + NFA_State(int sym, EquivClass* ec); + explicit NFA_State(CCL* ccl); + ~NFA_State() override; - void AddXtion(NFA_State* next_state) { xtions.push_back(next_state); } - NFA_state_list* Transitions() { return &xtions; } - void AddXtionsTo(NFA_state_list* ns); + void AddXtion(NFA_State* next_state) { xtions.push_back(next_state); } + NFA_state_list* Transitions() { return &xtions; } + void AddXtionsTo(NFA_state_list* ns); - void SetAccept(int accept_val) { accept = accept_val; } - int Accept() const { return accept; } + void SetAccept(int accept_val) { accept = accept_val; } + int Accept() const { return accept; } - // Returns a deep copy of this NFA state and everything it points - // to. Upon return, each state's marker is set to point to its - // copy. - NFA_State* DeepCopy(); + // Returns a deep copy of this NFA state and everything it points + // to. Upon return, each state's marker is set to point to its + // copy. + NFA_State* DeepCopy(); - void SetMark(NFA_State* m) { mark = m; } - NFA_State* Mark() const { return mark; } - void ClearMarks(); + void SetMark(NFA_State* m) { mark = m; } + NFA_State* Mark() const { return mark; } + void ClearMarks(); - void SetFirstTransIsBackRef() { first_trans_is_back_ref = true; } + void SetFirstTransIsBackRef() { first_trans_is_back_ref = true; } - int TransSym() const { return sym; } - CCL* TransCCL() const { return ccl; } - int ID() const { return id; } + int TransSym() const { return sym; } + CCL* TransCCL() const { return ccl; } + int ID() const { return id; } - NFA_state_list* EpsilonClosure(); + NFA_state_list* EpsilonClosure(); - void Describe(ODesc* d) const override; - void Dump(FILE* f); + void Describe(ODesc* d) const override; + void Dump(FILE* f); protected: - int sym; // if SYM_CCL, then use ccl - int id; // number that uniquely identifies this state - CCL* ccl; // if nil, then use sym - int accept; + int sym; // if SYM_CCL, then use ccl + int id; // number that uniquely identifies this state + CCL* ccl; // if nil, then use sym + int accept; - // Whether the first transition points backwards. Used - // to avoid reference-counting loops. - bool first_trans_is_back_ref; + // Whether the first transition points backwards. Used + // to avoid reference-counting loops. + bool first_trans_is_back_ref; - NFA_state_list xtions; - NFA_state_list* epsclosure; - NFA_State* mark; - }; + NFA_state_list xtions; + NFA_state_list* epsclosure; + NFA_State* mark; +}; -class EpsilonState : public NFA_State - { +class EpsilonState : public NFA_State { public: - EpsilonState() : NFA_State(SYM_EPSILON, nullptr) { } - }; + EpsilonState() : NFA_State(SYM_EPSILON, nullptr) {} +}; -class NFA_Machine : public Obj - { +class NFA_Machine : public Obj { public: - explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr); - ~NFA_Machine() override; + explicit NFA_Machine(NFA_State* first, NFA_State* final = nullptr); + ~NFA_Machine() override; - NFA_State* FirstState() const { return first_state; } + NFA_State* FirstState() const { return first_state; } - void SetFinalState(NFA_State* final) { final_state = final; } - NFA_State* FinalState() const { return final_state; } + void SetFinalState(NFA_State* final) { final_state = final; } + NFA_State* FinalState() const { return final_state; } - void AddAccept(int accept_val); + void AddAccept(int accept_val); - void MakeClosure() - { - MakePositiveClosure(); - MakeOptional(); - } - void MakeOptional(); - void MakePositiveClosure(); + void MakeClosure() { + MakePositiveClosure(); + MakeOptional(); + } + void MakeOptional(); + void MakePositiveClosure(); - // re{lower,upper}; upper can be NO_UPPER_BOUND = infinity. - void MakeRepl(int lower, int upper); + // re{lower,upper}; upper can be NO_UPPER_BOUND = infinity. + void MakeRepl(int lower, int upper); - void MarkBOL() { bol = 1; } - void MarkEOL() { eol = 1; } + void MarkBOL() { bol = 1; } + void MarkEOL() { eol = 1; } - NFA_Machine* DuplicateMachine(); - void LinkCopies(int n); - void InsertEpsilon(); - void AppendEpsilon(); + NFA_Machine* DuplicateMachine(); + void LinkCopies(int n); + void InsertEpsilon(); + void AppendEpsilon(); - void AppendState(NFA_State* new_state); - void AppendMachine(NFA_Machine* new_mach); + void AppendState(NFA_State* new_state); + void AppendMachine(NFA_Machine* new_mach); - void Describe(ODesc* d) const override; - void Dump(FILE* f); + void Describe(ODesc* d) const override; + void Dump(FILE* f); protected: - NFA_State* first_state; - NFA_State* final_state; - int bol, eol; - }; + NFA_State* first_state; + NFA_State* final_state; + int bol, eol; +}; extern NFA_Machine* make_alternate(NFA_Machine* m1, NFA_Machine* m2); @@ -141,5 +135,5 @@ extern NFA_state_list* epsilon_closure(NFA_state_list* states); // For sorting NFA states based on their ID fields (decreasing) extern bool NFA_state_cmp_neg(const NFA_State* v1, const NFA_State* v2); - } // namespace detail - } // namespace zeek +} // namespace detail +} // namespace zeek diff --git a/src/NetVar.cc b/src/NetVar.cc index 3389c7b224..a3d9323a0e 100644 --- a/src/NetVar.cc +++ b/src/NetVar.cc @@ -105,8 +105,7 @@ zeek::StringVal* cmd_line_bpf_filter; zeek::StringVal* global_hash_seed; -namespace zeek::detail - { +namespace zeek::detail { int watchdog_interval; @@ -194,29 +193,26 @@ zeek_uint_t bits_per_uid; zeek_uint_t tunnel_max_changes_per_connection; - } // namespace zeek::detail. The namespace has be closed here before we include the netvar_def - // files. +} // namespace zeek::detail + // files. // Because of how the BIF include files are built with namespaces already in them, // these files need to be included separately before the namespace is opened below. -static void bif_init_event_handlers() - { +static void bif_init_event_handlers() { #include "event.bif.netvar_init" - } +} -static void bif_init_net_var() - { +static void bif_init_net_var() { #include "const.bif.netvar_init" #include "packet_analysis.bif.netvar_init" #include "reporter.bif.netvar_init" #include "supervisor.bif.netvar_init" - } +} -static void init_bif_types() - { +static void init_bif_types() { #include "types.bif.netvar_init" - } +} #include "const.bif.netvar_def" #include "event.bif.netvar_def" @@ -226,127 +222,116 @@ static void init_bif_types() #include "types.bif.netvar_def" // Re-open the namespace now that the bif headers are all included. -namespace zeek::detail - { +namespace zeek::detail { -void init_event_handlers() - { - bif_init_event_handlers(); - } +void init_event_handlers() { bif_init_event_handlers(); } -void init_general_global_var() - { - table_expire_interval = id::find_val("table_expire_interval")->AsInterval(); - table_expire_delay = id::find_val("table_expire_delay")->AsInterval(); - table_incremental_step = id::find_val("table_incremental_step")->AsCount(); - packet_filter_default = id::find_val("packet_filter_default")->AsBool(); - sig_max_group_size = id::find_val("sig_max_group_size")->AsCount(); - check_for_unused_event_handlers = id::find_val("check_for_unused_event_handlers")->AsBool(); - record_all_packets = id::find_val("record_all_packets")->AsBool(); - bits_per_uid = id::find_val("bits_per_uid")->AsCount(); - } +void init_general_global_var() { + table_expire_interval = id::find_val("table_expire_interval")->AsInterval(); + table_expire_delay = id::find_val("table_expire_delay")->AsInterval(); + table_incremental_step = id::find_val("table_incremental_step")->AsCount(); + packet_filter_default = id::find_val("packet_filter_default")->AsBool(); + sig_max_group_size = id::find_val("sig_max_group_size")->AsCount(); + check_for_unused_event_handlers = id::find_val("check_for_unused_event_handlers")->AsBool(); + record_all_packets = id::find_val("record_all_packets")->AsBool(); + bits_per_uid = id::find_val("bits_per_uid")->AsCount(); +} -void init_builtin_types() - { - init_bif_types(); - id::detail::init_types(); - } +void init_builtin_types() { + init_bif_types(); + id::detail::init_types(); +} -void init_net_var() - { - bif_init_net_var(); +void init_net_var() { + bif_init_net_var(); - ignore_checksums = id::find_val("ignore_checksums")->AsBool(); - partial_connection_ok = id::find_val("partial_connection_ok")->AsBool(); - tcp_SYN_ack_ok = id::find_val("tcp_SYN_ack_ok")->AsBool(); - tcp_match_undelivered = id::find_val("tcp_match_undelivered")->AsBool(); + ignore_checksums = id::find_val("ignore_checksums")->AsBool(); + partial_connection_ok = id::find_val("partial_connection_ok")->AsBool(); + tcp_SYN_ack_ok = id::find_val("tcp_SYN_ack_ok")->AsBool(); + tcp_match_undelivered = id::find_val("tcp_match_undelivered")->AsBool(); - frag_timeout = id::find_val("frag_timeout")->AsInterval(); + frag_timeout = id::find_val("frag_timeout")->AsInterval(); - tcp_SYN_timeout = id::find_val("tcp_SYN_timeout")->AsInterval(); - tcp_session_timer = id::find_val("tcp_session_timer")->AsInterval(); - tcp_connection_linger = id::find_val("tcp_connection_linger")->AsInterval(); - tcp_attempt_delay = id::find_val("tcp_attempt_delay")->AsInterval(); - tcp_close_delay = id::find_val("tcp_close_delay")->AsInterval(); - tcp_reset_delay = id::find_val("tcp_reset_delay")->AsInterval(); - tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval(); + tcp_SYN_timeout = id::find_val("tcp_SYN_timeout")->AsInterval(); + tcp_session_timer = id::find_val("tcp_session_timer")->AsInterval(); + tcp_connection_linger = id::find_val("tcp_connection_linger")->AsInterval(); + tcp_attempt_delay = id::find_val("tcp_attempt_delay")->AsInterval(); + tcp_close_delay = id::find_val("tcp_close_delay")->AsInterval(); + tcp_reset_delay = id::find_val("tcp_reset_delay")->AsInterval(); + tcp_partial_close_delay = id::find_val("tcp_partial_close_delay")->AsInterval(); - tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount(); - tcp_max_above_hole_without_any_acks = - id::find_val("tcp_max_above_hole_without_any_acks")->AsCount(); - tcp_excessive_data_without_further_acks = - id::find_val("tcp_excessive_data_without_further_acks")->AsCount(); - tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount(); + tcp_max_initial_window = id::find_val("tcp_max_initial_window")->AsCount(); + tcp_max_above_hole_without_any_acks = id::find_val("tcp_max_above_hole_without_any_acks")->AsCount(); + tcp_excessive_data_without_further_acks = id::find_val("tcp_excessive_data_without_further_acks")->AsCount(); + tcp_max_old_segments = id::find_val("tcp_max_old_segments")->AsCount(); - non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval(); - tcp_inactivity_timeout = id::find_val("tcp_inactivity_timeout")->AsInterval(); - udp_inactivity_timeout = id::find_val("udp_inactivity_timeout")->AsInterval(); - icmp_inactivity_timeout = id::find_val("icmp_inactivity_timeout")->AsInterval(); + non_analyzed_lifetime = id::find_val("non_analyzed_lifetime")->AsInterval(); + tcp_inactivity_timeout = id::find_val("tcp_inactivity_timeout")->AsInterval(); + udp_inactivity_timeout = id::find_val("udp_inactivity_timeout")->AsInterval(); + icmp_inactivity_timeout = id::find_val("icmp_inactivity_timeout")->AsInterval(); - tcp_storm_thresh = id::find_val("tcp_storm_thresh")->AsCount(); - tcp_storm_interarrival_thresh = id::find_val("tcp_storm_interarrival_thresh")->AsInterval(); + tcp_storm_thresh = id::find_val("tcp_storm_thresh")->AsCount(); + tcp_storm_interarrival_thresh = id::find_val("tcp_storm_interarrival_thresh")->AsInterval(); - tcp_content_deliver_all_orig = bool(id::find_val("tcp_content_deliver_all_orig")->AsBool()); - tcp_content_deliver_all_resp = bool(id::find_val("tcp_content_deliver_all_resp")->AsBool()); + tcp_content_deliver_all_orig = bool(id::find_val("tcp_content_deliver_all_orig")->AsBool()); + tcp_content_deliver_all_resp = bool(id::find_val("tcp_content_deliver_all_resp")->AsBool()); - udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool()); - udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool()); - udp_content_delivery_ports_use_resp = bool( - id::find_val("udp_content_delivery_ports_use_resp")->AsBool()); + udp_content_deliver_all_orig = bool(id::find_val("udp_content_deliver_all_orig")->AsBool()); + udp_content_deliver_all_resp = bool(id::find_val("udp_content_deliver_all_resp")->AsBool()); + udp_content_delivery_ports_use_resp = bool(id::find_val("udp_content_delivery_ports_use_resp")->AsBool()); - dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval(); - rpc_timeout = id::find_val("rpc_timeout")->AsInterval(); + dns_session_timeout = id::find_val("dns_session_timeout")->AsInterval(); + rpc_timeout = id::find_val("rpc_timeout")->AsInterval(); - watchdog_interval = int(id::find_val("watchdog_interval")->AsInterval()); + watchdog_interval = int(id::find_val("watchdog_interval")->AsInterval()); - max_timer_expires = id::find_val("max_timer_expires")->AsCount(); + max_timer_expires = id::find_val("max_timer_expires")->AsCount(); - mime_segment_length = id::find_val("mime_segment_length")->AsCount(); - mime_segment_overlap_length = id::find_val("mime_segment_overlap_length")->AsCount(); + mime_segment_length = id::find_val("mime_segment_length")->AsCount(); + mime_segment_overlap_length = id::find_val("mime_segment_overlap_length")->AsCount(); - http_entity_data_delivery_size = id::find_val("http_entity_data_delivery_size")->AsCount(); - truncate_http_URI = id::find_val("truncate_http_URI")->AsInt(); + http_entity_data_delivery_size = id::find_val("http_entity_data_delivery_size")->AsCount(); + truncate_http_URI = id::find_val("truncate_http_URI")->AsInt(); - dns_skip_all_auth = id::find_val("dns_skip_all_auth")->AsBool(); - dns_skip_all_addl = id::find_val("dns_skip_all_addl")->AsBool(); - dns_max_queries = id::find_val("dns_max_queries")->AsCount(); + dns_skip_all_auth = id::find_val("dns_skip_all_auth")->AsBool(); + dns_skip_all_addl = id::find_val("dns_skip_all_addl")->AsBool(); + dns_max_queries = id::find_val("dns_max_queries")->AsCount(); - orig_addr_anonymization = 0; - if ( const auto& id = id::find("orig_addr_anonymization") ) - if ( const auto& v = id->GetVal() ) - orig_addr_anonymization = v->AsInt(); - resp_addr_anonymization = 0; - if ( const auto& id = id::find("resp_addr_anonymization") ) - if ( const auto& v = id->GetVal() ) - resp_addr_anonymization = v->AsInt(); - other_addr_anonymization = 0; - if ( const auto& id = id::find("other_addr_anonymization") ) - if ( const auto& v = id->GetVal() ) - other_addr_anonymization = v->AsInt(); + orig_addr_anonymization = 0; + if ( const auto& id = id::find("orig_addr_anonymization") ) + if ( const auto& v = id->GetVal() ) + orig_addr_anonymization = v->AsInt(); + resp_addr_anonymization = 0; + if ( const auto& id = id::find("resp_addr_anonymization") ) + if ( const auto& v = id->GetVal() ) + resp_addr_anonymization = v->AsInt(); + other_addr_anonymization = 0; + if ( const auto& id = id::find("other_addr_anonymization") ) + if ( const auto& v = id->GetVal() ) + other_addr_anonymization = v->AsInt(); - connection_status_update_interval = 0.0; - if ( const auto& id = id::find("connection_status_update_interval") ) - if ( const auto& v = id->GetVal() ) - connection_status_update_interval = v->AsInterval(); + connection_status_update_interval = 0.0; + if ( const auto& id = id::find("connection_status_update_interval") ) + if ( const auto& v = id->GetVal() ) + connection_status_update_interval = v->AsInterval(); - expensive_profiling_multiple = id::find_val("expensive_profiling_multiple")->AsCount(); - profiling_interval = id::find_val("profiling_interval")->AsInterval(); - segment_profiling = id::find_val("segment_profiling")->AsBool(); + expensive_profiling_multiple = id::find_val("expensive_profiling_multiple")->AsCount(); + profiling_interval = id::find_val("profiling_interval")->AsInterval(); + segment_profiling = id::find_val("segment_profiling")->AsBool(); - pkt_profile_mode = id::find_val("pkt_profile_mode")->InternalInt(); - pkt_profile_freq = id::find_val("pkt_profile_freq")->AsDouble(); + pkt_profile_mode = id::find_val("pkt_profile_mode")->InternalInt(); + pkt_profile_freq = id::find_val("pkt_profile_freq")->AsDouble(); - load_sample_freq = id::find_val("load_sample_freq")->AsCount(); + load_sample_freq = id::find_val("load_sample_freq")->AsCount(); - dpd_reassemble_first_packets = id::find_val("dpd_reassemble_first_packets")->AsBool(); - dpd_buffer_size = id::find_val("dpd_buffer_size")->AsCount(); - dpd_max_packets = id::find_val("dpd_max_packets")->AsCount(); - dpd_match_only_beginning = id::find_val("dpd_match_only_beginning")->AsBool(); - dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool(); - dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool(); + dpd_reassemble_first_packets = id::find_val("dpd_reassemble_first_packets")->AsBool(); + dpd_buffer_size = id::find_val("dpd_buffer_size")->AsCount(); + dpd_max_packets = id::find_val("dpd_max_packets")->AsCount(); + dpd_match_only_beginning = id::find_val("dpd_match_only_beginning")->AsBool(); + dpd_late_match_stop = id::find_val("dpd_late_match_stop")->AsBool(); + dpd_ignore_ports = id::find_val("dpd_ignore_ports")->AsBool(); - tunnel_max_changes_per_connection = - id::find_val("Tunnel::max_changes_per_connection")->AsCount(); - } + tunnel_max_changes_per_connection = id::find_val("Tunnel::max_changes_per_connection")->AsCount(); +} - } // namespace zeek::detail +} // namespace zeek::detail diff --git a/src/NetVar.h b/src/NetVar.h index c56d57f5d4..21bbcffd22 100644 --- a/src/NetVar.h +++ b/src/NetVar.h @@ -6,8 +6,7 @@ #include "zeek/Stats.h" #include "zeek/Val.h" -namespace zeek::detail - { +namespace zeek::detail { extern int watchdog_interval; @@ -103,7 +102,7 @@ extern void init_event_handlers(); extern void init_net_var(); extern void init_builtin_types(); - } // namespace zeek::detail +} // namespace zeek::detail #include "const.bif.netvar_h" #include "event.bif.netvar_h" diff --git a/src/Notifier.cc b/src/Notifier.cc index c8bef502c7..ce3e8d68eb 100644 --- a/src/Notifier.cc +++ b/src/Notifier.cc @@ -8,84 +8,68 @@ zeek::notifier::detail::Registry zeek::notifier::detail::registry; -namespace zeek::notifier::detail - { +namespace zeek::notifier::detail { -Receiver::Receiver() - { - DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this); - } +Receiver::Receiver() { DBG_LOG(DBG_NOTIFIERS, "creating receiver %p", this); } -Receiver::~Receiver() - { - DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this); - } +Receiver::~Receiver() { DBG_LOG(DBG_NOTIFIERS, "deleting receiver %p", this); } -Registry::~Registry() - { - while ( registrations.begin() != registrations.end() ) - Unregister(registrations.begin()->first); - } +Registry::~Registry() { + while ( registrations.begin() != registrations.end() ) + Unregister(registrations.begin()->first); +} -void Registry::Register(Modifiable* m, Receiver* r) - { - DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r); +void Registry::Register(Modifiable* m, Receiver* r) { + DBG_LOG(DBG_NOTIFIERS, "registering object %p for receiver %p", m, r); - registrations.insert({m, r}); - ++m->num_receivers; - } + registrations.insert({m, r}); + ++m->num_receivers; +} -void Registry::Unregister(Modifiable* m, Receiver* r) - { - DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r); +void Registry::Unregister(Modifiable* m, Receiver* r) { + DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from receiver %p", m, r); - auto x = registrations.equal_range(m); - for ( auto i = x.first; i != x.second; i++ ) - { - if ( i->second == r ) - { - --i->first->num_receivers; - registrations.erase(i); - break; - } - } - } + auto x = registrations.equal_range(m); + for ( auto i = x.first; i != x.second; i++ ) { + if ( i->second == r ) { + --i->first->num_receivers; + registrations.erase(i); + break; + } + } +} -void Registry::Unregister(Modifiable* m) - { - DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m); +void Registry::Unregister(Modifiable* m) { + DBG_LOG(DBG_NOTIFIERS, "unregistering object %p from all notifiers", m); - auto x = registrations.equal_range(m); - for ( auto i = x.first; i != x.second; i++ ) - --i->first->num_receivers; + auto x = registrations.equal_range(m); + for ( auto i = x.first; i != x.second; i++ ) + --i->first->num_receivers; - registrations.erase(x.first, x.second); - } + registrations.erase(x.first, x.second); +} -void Registry::Modified(Modifiable* m) - { - DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m); +void Registry::Modified(Modifiable* m) { + DBG_LOG(DBG_NOTIFIERS, "object %p has been modified", m); - auto x = registrations.equal_range(m); - for ( auto i = x.first; i != x.second; i++ ) - i->second->Modified(m); - } + auto x = registrations.equal_range(m); + for ( auto i = x.first; i != x.second; i++ ) + i->second->Modified(m); +} -void Registry::Terminate() - { - std::set receivers; +void Registry::Terminate() { + std::set receivers; - for ( auto& r : registrations ) - receivers.emplace(r.second); + for ( auto& r : registrations ) + receivers.emplace(r.second); - for ( auto& r : receivers ) - r->Terminate(); - } + for ( auto& r : receivers ) + r->Terminate(); +} -Modifiable::~Modifiable() - { - if ( num_receivers ) - registry.Unregister(this); - } +Modifiable::~Modifiable() { + if ( num_receivers ) + registry.Unregister(this); +} - } // namespace zeek::notifier::detail +} // namespace zeek::notifier::detail diff --git a/src/Notifier.h b/src/Notifier.h index de15c3ea9e..592c491219 100644 --- a/src/Notifier.h +++ b/src/Notifier.h @@ -10,86 +10,83 @@ #include #include -namespace zeek::notifier::detail - { +namespace zeek::notifier::detail { class Modifiable; /** Interface class for receivers of notifications. */ -class Receiver - { +class Receiver { public: - Receiver(); - virtual ~Receiver(); + Receiver(); + virtual ~Receiver(); - /** - * Callback executed when a register object has been modified. - * - * @param m object that was modified - */ - virtual void Modified(Modifiable* m) = 0; + /** + * Callback executed when a register object has been modified. + * + * @param m object that was modified + */ + virtual void Modified(Modifiable* m) = 0; - /** - * Callback executed when notification registry is terminating and - * no further modifications can possibly occur. - */ - virtual void Terminate() { } - }; + /** + * Callback executed when notification registry is terminating and + * no further modifications can possibly occur. + */ + virtual void Terminate() {} +}; /** Singleton class tracking all notification requests globally. */ -class Registry - { +class Registry { public: - ~Registry(); + ~Registry(); - /** - * Registers a receiver to be informed when a modifiable object has - * changed. - * - * @param m object to track. Does not take ownership, but the object - * will automatically unregister itself on destruction. - * - * @param r receiver to notify on changes. Does not take ownership, - * the receiver must remain valid as long as the registration stays - * in place. - */ - void Register(Modifiable* m, Receiver* r); + /** + * Registers a receiver to be informed when a modifiable object has + * changed. + * + * @param m object to track. Does not take ownership, but the object + * will automatically unregister itself on destruction. + * + * @param r receiver to notify on changes. Does not take ownership, + * the receiver must remain valid as long as the registration stays + * in place. + */ + void Register(Modifiable* m, Receiver* r); - /** - * Cancels a receiver's request to be informed about an object's - * modification. The arguments to the method must match what was - * originally registered. - * - * @param m object to no longer track. - * - * @param r receiver to no longer notify. - */ - void Unregister(Modifiable* m, Receiver* Receiver); + /** + * Cancels a receiver's request to be informed about an object's + * modification. The arguments to the method must match what was + * originally registered. + * + * @param m object to no longer track. + * + * @param r receiver to no longer notify. + */ + void Unregister(Modifiable* m, Receiver* Receiver); - /** - * Cancels any active receiver requests to be informed about a - * particular object's modifications. - * - * @param m object to no longer track. - */ - void Unregister(Modifiable* m); + /** + * Cancels any active receiver requests to be informed about a + * particular object's modifications. + * + * @param m object to no longer track. + */ + void Unregister(Modifiable* m); - /** - * Notifies all receivers that no further modifications will occur - * as the registry is shutting down. - */ - void Terminate(); + /** + * Notifies all receivers that no further modifications will occur + * as the registry is shutting down. + */ + void Terminate(); private: - friend class Modifiable; + friend class Modifiable; - // Inform all registered receivers of a modification to an object. - // Will be called from the object itself. - void Modified(Modifiable* m); + // Inform all registered receivers of a modification to an object. + // Will be called from the object itself. + void Modified(Modifiable* m); - using ModifiableMap = std::unordered_multimap; - ModifiableMap registrations; - }; + using ModifiableMap = std::unordered_multimap; + ModifiableMap registrations; +}; /** * Singleton object tracking all global notification requests. @@ -100,26 +97,24 @@ extern Registry registry; * Base class for objects that can trigger notifications to receivers when * modified. */ -class Modifiable - { +class Modifiable { public: - /** - * Calling this method signals to all registered receivers that the - * object has been modified. - */ - void Modified() - { - if ( num_receivers ) - registry.Modified(this); - } + /** + * Calling this method signals to all registered receivers that the + * object has been modified. + */ + void Modified() { + if ( num_receivers ) + registry.Modified(this); + } protected: - friend class Registry; + friend class Registry; - virtual ~Modifiable(); + virtual ~Modifiable(); - // Number of currently registered receivers. - uint64_t num_receivers = 0; - }; + // Number of currently registered receivers. + uint64_t num_receivers = 0; +}; - } // namespace zeek::notifier::detail +} // namespace zeek::notifier::detail diff --git a/src/Obj.cc b/src/Obj.cc index 1bf384aeb4..a3b36dee76 100644 --- a/src/Obj.cc +++ b/src/Obj.cc @@ -11,208 +11,180 @@ #include "zeek/Func.h" #include "zeek/plugin/Manager.h" -namespace zeek - { -namespace detail - { +namespace zeek { +namespace detail { Location start_location("", 0, 0, 0, 0); Location end_location("", 0, 0, 0, 0); -void Location::Describe(ODesc* d) const - { - if ( filename ) - { - d->Add(filename); +void Location::Describe(ODesc* d) const { + if ( filename ) { + d->Add(filename); - if ( first_line == 0 ) - return; + if ( first_line == 0 ) + return; - d->AddSP(","); - } + d->AddSP(","); + } - if ( last_line != first_line ) - { - d->Add("lines "); - d->Add(first_line); - d->Add("-"); - d->Add(last_line); - } - else - { - d->Add("line "); - d->Add(first_line); - } - } + if ( last_line != first_line ) { + d->Add("lines "); + d->Add(first_line); + d->Add("-"); + d->Add(last_line); + } + else { + d->Add("line "); + d->Add(first_line); + } +} -bool Location::operator==(const Location& l) const - { - if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) ) - return first_line == l.first_line && last_line == l.last_line; - else - return false; - } +bool Location::operator==(const Location& l) const { + if ( filename == l.filename || (filename && l.filename && util::streq(filename, l.filename)) ) + return first_line == l.first_line && last_line == l.last_line; + else + return false; +} - } // namespace detail +} // namespace detail int Obj::suppress_errors = 0; -Obj::~Obj() - { - if ( notify_plugins ) - PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this)); +Obj::~Obj() { + if ( notify_plugins ) + PLUGIN_HOOK_VOID(HOOK_OBJ_DTOR, HookObjDtor(this)); - delete location; - } + delete location; +} -void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, - const detail::Location* expr_location) const - { - ODesc d; - DoMsg(&d, msg, obj2, pinpoint_only, expr_location); - reporter->Warning("%s", d.Description()); - reporter->PopLocation(); - } +void Obj::Warn(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const { + ODesc d; + DoMsg(&d, msg, obj2, pinpoint_only, expr_location); + reporter->Warning("%s", d.Description()); + reporter->PopLocation(); +} -void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, - const detail::Location* expr_location) const - { - if ( suppress_errors ) - return; +void Obj::Error(const char* msg, const Obj* obj2, bool pinpoint_only, const detail::Location* expr_location) const { + if ( suppress_errors ) + return; - ODesc d; - DoMsg(&d, msg, obj2, pinpoint_only, expr_location); - reporter->Error("%s", d.Description()); - reporter->PopLocation(); - } + ODesc d; + DoMsg(&d, msg, obj2, pinpoint_only, expr_location); + reporter->Error("%s", d.Description()); + reporter->PopLocation(); +} -void Obj::BadTag(const char* msg, const char* t1, const char* t2) const - { - char out[512]; +void Obj::BadTag(const char* msg, const char* t1, const char* t2) const { + char out[512]; - if ( t2 ) - snprintf(out, sizeof(out), "%s (%s/%s)", msg, t1, t2); - else if ( t1 ) - snprintf(out, sizeof(out), "%s (%s)", msg, t1); - else - snprintf(out, sizeof(out), "%s", msg); + if ( t2 ) + snprintf(out, sizeof(out), "%s (%s/%s)", msg, t1, t2); + else if ( t1 ) + snprintf(out, sizeof(out), "%s (%s)", msg, t1); + else + snprintf(out, sizeof(out), "%s", msg); - ODesc d; - DoMsg(&d, out); - reporter->FatalErrorWithCore("%s", d.Description()); - reporter->PopLocation(); - } + ODesc d; + DoMsg(&d, out); + reporter->FatalErrorWithCore("%s", d.Description()); + reporter->PopLocation(); +} -void Obj::Internal(const char* msg) const - { - ODesc d; - DoMsg(&d, msg); - auto rcs = render_call_stack(); +void Obj::Internal(const char* msg) const { + ODesc d; + DoMsg(&d, msg); + auto rcs = render_call_stack(); - if ( rcs.empty() ) - reporter->InternalError("%s", d.Description()); - else - reporter->InternalError("%s, call stack: %s", d.Description(), rcs.data()); + if ( rcs.empty() ) + reporter->InternalError("%s", d.Description()); + else + reporter->InternalError("%s, call stack: %s", d.Description(), rcs.data()); - reporter->PopLocation(); - } + reporter->PopLocation(); +} -void Obj::InternalWarning(const char* msg) const - { - ODesc d; - DoMsg(&d, msg); - reporter->InternalWarning("%s", d.Description()); - reporter->PopLocation(); - } +void Obj::InternalWarning(const char* msg) const { + ODesc d; + DoMsg(&d, msg); + reporter->InternalWarning("%s", d.Description()); + reporter->PopLocation(); +} -void Obj::AddLocation(ODesc* d) const - { - if ( ! location ) - { - d->Add(""); - return; - } +void Obj::AddLocation(ODesc* d) const { + if ( ! location ) { + d->Add(""); + return; + } - location->Describe(d); - } + location->Describe(d); +} -bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) - { - if ( ! start || ! end ) - return false; +bool Obj::SetLocationInfo(const detail::Location* start, const detail::Location* end) { + if ( ! start || ! end ) + return false; - if ( end->filename && ! util::streq(start->filename, end->filename) ) - return false; + if ( end->filename && ! util::streq(start->filename, end->filename) ) + return false; - if ( location && (start == &detail::no_location || end == &detail::no_location) ) - // We already have a better location, so don't use this one. - return true; + if ( location && (start == &detail::no_location || end == &detail::no_location) ) + // We already have a better location, so don't use this one. + return true; - delete location; + delete location; - location = new detail::Location(start->filename, start->first_line, end->last_line, - start->first_column, end->last_column); + location = + new detail::Location(start->filename, start->first_line, end->last_line, start->first_column, end->last_column); - return true; - } + return true; +} -void Obj::UpdateLocationEndInfo(const detail::Location& end) - { - if ( ! location ) - SetLocationInfo(&end, &end); +void Obj::UpdateLocationEndInfo(const detail::Location& end) { + if ( ! location ) + SetLocationInfo(&end, &end); - location->last_line = end.last_line; - location->last_column = end.last_column; - } + location->last_line = end.last_line; + location->last_column = end.last_column; +} void Obj::DoMsg(ODesc* d, const char s1[], const Obj* obj2, bool pinpoint_only, - const detail::Location* expr_location) const - { - d->SetShort(); + const detail::Location* expr_location) const { + d->SetShort(); - d->Add(s1); - PinPoint(d, obj2, pinpoint_only); + d->Add(s1); + PinPoint(d, obj2, pinpoint_only); - const detail::Location* loc2 = nullptr; - if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && - *obj2->GetLocationInfo() != *GetLocationInfo() ) - loc2 = obj2->GetLocationInfo(); - else if ( expr_location ) - loc2 = expr_location; + const detail::Location* loc2 = nullptr; + if ( obj2 && obj2->GetLocationInfo() != &detail::no_location && *obj2->GetLocationInfo() != *GetLocationInfo() ) + loc2 = obj2->GetLocationInfo(); + else if ( expr_location ) + loc2 = expr_location; - reporter->PushLocation(GetLocationInfo(), loc2); - } + reporter->PushLocation(GetLocationInfo(), loc2); +} -void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const - { - d->Add(" ("); - Describe(d); - if ( obj2 && ! pinpoint_only ) - { - d->Add(" and "); - obj2->Describe(d); - } +void Obj::PinPoint(ODesc* d, const Obj* obj2, bool pinpoint_only) const { + d->Add(" ("); + Describe(d); + if ( obj2 && ! pinpoint_only ) { + d->Add(" and "); + obj2->Describe(d); + } - d->Add(")"); - } + d->Add(")"); +} -void Obj::Print() const - { - static File fstderr(stderr); - ODesc d(DESC_READABLE, &fstderr); - Describe(&d); - d.Add("\n"); - } +void Obj::Print() const { + static File fstderr(stderr); + ODesc d(DESC_READABLE, &fstderr); + Describe(&d); + d.Add("\n"); +} -void bad_ref(int type) - { - reporter->InternalError("bad reference count [%d]", type); - abort(); - } +void bad_ref(int type) { + reporter->InternalError("bad reference count [%d]", type); + abort(); +} -void obj_delete_func(void* v) - { - Unref((Obj*)v); - } +void obj_delete_func(void* v) { Unref((Obj*)v); } - } // namespace zeek +} // namespace zeek diff --git a/src/Obj.h b/src/Obj.h index bc2eba6aa4..fda839d6f0 100644 --- a/src/Obj.h +++ b/src/Obj.h @@ -6,34 +6,28 @@ #include -namespace zeek - { +namespace zeek { class ODesc; -namespace detail - { +namespace detail { -class Location final - { +class Location final { public: - constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept - : filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), - last_column(col_l) - { - } + constexpr Location(const char* fname, int line_f, int line_l, int col_f, int col_l) noexcept + : filename(fname), first_line(line_f), last_line(line_l), first_column(col_f), last_column(col_l) {} - Location() = default; + Location() = default; - void Describe(ODesc* d) const; + void Describe(ODesc* d) const; - bool operator==(const Location& l) const; - bool operator!=(const Location& l) const { return ! (*this == l); } + bool operator==(const Location& l) const; + bool operator!=(const Location& l) const { return ! (*this == l); } - const char* filename = nullptr; - int first_line = 0, last_line = 0; - int first_column = 0, last_column = 0; - }; + const char* filename = nullptr; + int first_line = 0, last_line = 0; + int first_column = 0, last_column = 0; +}; #define YYLTYPE zeek::detail::yyltype using yyltype = Location; @@ -48,154 +42,138 @@ extern Location start_location; extern Location end_location; // Used by parser to set the above. -inline void set_location(const Location loc) - { - start_location = end_location = loc; - } +inline void set_location(const Location loc) { start_location = end_location = loc; } -inline void set_location(const Location start, const Location end) - { - start_location = start; - end_location = end; - } +inline void set_location(const Location start, const Location end) { + start_location = start; + end_location = end; +} - } // namespace detail +} // namespace detail -class Obj - { +class Obj { public: - Obj() - { - // A bit of a hack. We'd like to associate location - // information with every object created when parsing, - // since for them, the location is generally well-defined. - // We could maintain a separate flag that tells us whether - // we're inside a parse, but the parser also sets the - // location to no_location when it's done, so it makes - // sense to just check for that. *However*, start_location - // and end_location are maintained as their own objects - // rather than pointers or references, so we can't directly - // check them for equality with no_location. So instead - // we check for whether start_location has a line number - // of 0, which should only happen if it's been assigned - // to no_location (or hasn't been initialized at all). - location = nullptr; - if ( detail::start_location.first_line != 0 ) - SetLocationInfo(&detail::start_location, &detail::end_location); - } + Obj() { + // A bit of a hack. We'd like to associate location + // information with every object created when parsing, + // since for them, the location is generally well-defined. + // We could maintain a separate flag that tells us whether + // we're inside a parse, but the parser also sets the + // location to no_location when it's done, so it makes + // sense to just check for that. *However*, start_location + // and end_location are maintained as their own objects + // rather than pointers or references, so we can't directly + // check them for equality with no_location. So instead + // we check for whether start_location has a line number + // of 0, which should only happen if it's been assigned + // to no_location (or hasn't been initialized at all). + location = nullptr; + if ( detail::start_location.first_line != 0 ) + SetLocationInfo(&detail::start_location, &detail::end_location); + } - virtual ~Obj(); + virtual ~Obj(); - /* disallow copying */ - Obj(const Obj&) = delete; - Obj& operator=(const Obj&) = delete; + /* disallow copying */ + Obj(const Obj&) = delete; + Obj& operator=(const Obj&) = delete; - // Report user warnings/errors. If obj2 is given, then it's - // included in the message, though if pinpoint_only is non-zero, - // then obj2 is only used to pinpoint the location. - void Warn(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false, - const detail::Location* expr_location = nullptr) const; - void Error(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false, - const detail::Location* expr_location = nullptr) const; + // Report user warnings/errors. If obj2 is given, then it's + // included in the message, though if pinpoint_only is non-zero, + // then obj2 is only used to pinpoint the location. + void Warn(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false, + const detail::Location* expr_location = nullptr) const; + void Error(const char* msg, const Obj* obj2 = nullptr, bool pinpoint_only = false, + const detail::Location* expr_location = nullptr) const; - // Report internal errors. - void BadTag(const char* msg, const char* t1 = nullptr, const char* t2 = nullptr) const; -#define CHECK_TAG(t1, t2, text, tag_to_text_func) \ - { \ - if ( t1 != t2 ) \ - BadTag(text, tag_to_text_func(t1), tag_to_text_func(t2)); \ - } + // Report internal errors. + void BadTag(const char* msg, const char* t1 = nullptr, const char* t2 = nullptr) const; +#define CHECK_TAG(t1, t2, text, tag_to_text_func) \ + { \ + if ( t1 != t2 ) \ + BadTag(text, tag_to_text_func(t1), tag_to_text_func(t2)); \ + } - [[noreturn]] void Internal(const char* msg) const; - void InternalWarning(const char* msg) const; + [[noreturn]] void Internal(const char* msg) const; + void InternalWarning(const char* msg) const; - virtual void Describe(ODesc* d) const {/* FIXME: Add code */}; + virtual void Describe(ODesc* d) const {/* FIXME: Add code */}; - void AddLocation(ODesc* d) const; + void AddLocation(ODesc* d) const; - // Get location info for debugging. - virtual const detail::Location* GetLocationInfo() const - { - return location ? location : &detail::no_location; - } + // Get location info for debugging. + virtual const detail::Location* GetLocationInfo() const { return location ? location : &detail::no_location; } - virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); } + virtual bool SetLocationInfo(const detail::Location* loc) { return SetLocationInfo(loc, loc); } - // Location = range from start to end. - virtual bool SetLocationInfo(const detail::Location* start, const detail::Location* end); + // Location = range from start to end. + virtual bool SetLocationInfo(const detail::Location* start, const detail::Location* end); - // Set new end-of-location information. This is used to - // extend compound objects such as statement lists. - virtual void UpdateLocationEndInfo(const detail::Location& end); + // Set new end-of-location information. This is used to + // extend compound objects such as statement lists. + virtual void UpdateLocationEndInfo(const detail::Location& end); - // Enable notification of plugins when this objects gets destroyed. - void NotifyPluginsOnDtor() { notify_plugins = true; } + // Enable notification of plugins when this objects gets destroyed. + void NotifyPluginsOnDtor() { notify_plugins = true; } - int RefCnt() const { return ref_cnt; } + int RefCnt() const { return ref_cnt; } - // Helper class to temporarily suppress errors - // as long as there exist any instances. - class SuppressErrors - { - public: - SuppressErrors() { ++Obj::suppress_errors; } - ~SuppressErrors() { --Obj::suppress_errors; } - }; + // Helper class to temporarily suppress errors + // as long as there exist any instances. + class SuppressErrors { + public: + SuppressErrors() { ++Obj::suppress_errors; } + ~SuppressErrors() { --Obj::suppress_errors; } + }; - void Print() const; + void Print() const; protected: - detail::Location* location; // all that matters in real estate + detail::Location* location; // all that matters in real estate private: - friend class SuppressErrors; + friend class SuppressErrors; - void DoMsg(ODesc* d, const char s1[], const Obj* obj2 = nullptr, bool pinpoint_only = false, - const detail::Location* expr_location = nullptr) const; - void PinPoint(ODesc* d, const Obj* obj2 = nullptr, bool pinpoint_only = false) const; + void DoMsg(ODesc* d, const char s1[], const Obj* obj2 = nullptr, bool pinpoint_only = false, + const detail::Location* expr_location = nullptr) const; + void PinPoint(ODesc* d, const Obj* obj2 = nullptr, bool pinpoint_only = false) const; - friend inline void Ref(Obj* o); - friend inline void Unref(Obj* o); + friend inline void Ref(Obj* o); + friend inline void Unref(Obj* o); - int ref_cnt = 1; - bool notify_plugins = false; + int ref_cnt = 1; + bool notify_plugins = false; - // If non-zero, do not print runtime errors. Useful for - // speculative evaluation. - static int suppress_errors; - }; + // If non-zero, do not print runtime errors. Useful for + // speculative evaluation. + static int suppress_errors; +}; // Sometimes useful when dealing with Obj subclasses that have their // own (protected) versions of Error. -inline void Error(const Obj* o, const char* msg) - { - o->Error(msg); - } +inline void Error(const Obj* o, const char* msg) { o->Error(msg); } [[noreturn]] extern void bad_ref(int type); -inline void Ref(Obj* o) - { - if ( ++(o->ref_cnt) <= 1 ) - bad_ref(0); - if ( o->ref_cnt == INT_MAX ) - bad_ref(1); - } +inline void Ref(Obj* o) { + if ( ++(o->ref_cnt) <= 1 ) + bad_ref(0); + if ( o->ref_cnt == INT_MAX ) + bad_ref(1); +} -inline void Unref(Obj* o) - { - if ( o && --o->ref_cnt <= 0 ) - { - if ( o->ref_cnt < 0 ) - bad_ref(2); - delete o; +inline void Unref(Obj* o) { + if ( o && --o->ref_cnt <= 0 ) { + if ( o->ref_cnt < 0 ) + bad_ref(2); + delete o; - // We could do the following if o were passed by reference. - // o = (Obj*) 0xcd; - } - } + // We could do the following if o were passed by reference. + // o = (Obj*) 0xcd; + } +} // A dict_delete_func that knows to Unref() dictionary entries. extern void obj_delete_func(void* v); - } // namespace zeek +} // namespace zeek diff --git a/src/OpaqueVal.cc b/src/OpaqueVal.cc index e1389e82d8..c323cb7a2e 100644 --- a/src/OpaqueVal.cc +++ b/src/OpaqueVal.cc @@ -23,1125 +23,969 @@ #include "zeek/probabilistic/BloomFilter.h" #include "zeek/probabilistic/CardinalityCounter.h" -namespace zeek - { +namespace zeek { // Helper to retrieve a broker value out of a broker::vector at a specified // index, and casted to the expected destination type. -template -inline bool get_vector_idx(const V& v, unsigned int i, D* dst) - { - if ( i >= v.size() ) - return false; +template +inline bool get_vector_idx(const V& v, unsigned int i, D* dst) { + if ( i >= v.size() ) + return false; - auto x = broker::get_if(&v[i]); - if ( ! x ) - return false; + auto x = broker::get_if(&v[i]); + if ( ! x ) + return false; - *dst = static_cast(*x); - return true; - } + *dst = static_cast(*x); + return true; +} -OpaqueMgr* OpaqueMgr::mgr() - { - static OpaqueMgr mgr; - return &mgr; - } +OpaqueMgr* OpaqueMgr::mgr() { + static OpaqueMgr mgr; + return &mgr; +} -OpaqueVal::OpaqueVal(OpaqueTypePtr t) : Val(std::move(t)) { } +OpaqueVal::OpaqueVal(OpaqueTypePtr t) : Val(std::move(t)) {} -const std::string& OpaqueMgr::TypeID(const OpaqueVal* v) const - { - auto x = _types.find(v->OpaqueName()); +const std::string& OpaqueMgr::TypeID(const OpaqueVal* v) const { + auto x = _types.find(v->OpaqueName()); - if ( x == _types.end() ) - reporter->InternalError("OpaqueMgr::TypeID: opaque type %s not registered", - v->OpaqueName()); + if ( x == _types.end() ) + reporter->InternalError("OpaqueMgr::TypeID: opaque type %s not registered", v->OpaqueName()); - return x->first; - } + return x->first; +} -OpaqueValPtr OpaqueMgr::Instantiate(const std::string& id) const - { - auto x = _types.find(id); - return x != _types.end() ? (*x->second)() : nullptr; - } +OpaqueValPtr OpaqueMgr::Instantiate(const std::string& id) const { + auto x = _types.find(id); + return x != _types.end() ? (*x->second)() : nullptr; +} -broker::expected OpaqueVal::Serialize() const - { - auto type = OpaqueMgr::mgr()->TypeID(this); +broker::expected OpaqueVal::Serialize() const { + auto type = OpaqueMgr::mgr()->TypeID(this); - auto d = DoSerialize(); - if ( ! d ) - return d.error(); + auto d = DoSerialize(); + if ( ! d ) + return d.error(); - return {broker::vector{std::move(type), std::move(*d)}}; - } + return {broker::vector{std::move(type), std::move(*d)}}; +} -OpaqueValPtr OpaqueVal::Unserialize(const broker::data& data) - { - auto v = broker::get_if(&data); +OpaqueValPtr OpaqueVal::Unserialize(const broker::data& data) { + auto v = broker::get_if(&data); - if ( ! (v && v->size() == 2) ) - return nullptr; + if ( ! (v && v->size() == 2) ) + return nullptr; - auto type = broker::get_if(&(*v)[0]); - if ( ! type ) - return nullptr; + auto type = broker::get_if(&(*v)[0]); + if ( ! type ) + return nullptr; - auto val = OpaqueMgr::mgr()->Instantiate(*type); - if ( ! val ) - return nullptr; + auto val = OpaqueMgr::mgr()->Instantiate(*type); + if ( ! val ) + return nullptr; - if ( ! val->DoUnserialize((*v)[1]) ) - return nullptr; + if ( ! val->DoUnserialize((*v)[1]) ) + return nullptr; - return val; - } + return val; +} -broker::expected OpaqueVal::SerializeType(const TypePtr& t) - { - if ( t->InternalType() == TYPE_INTERNAL_ERROR ) - return broker::ec::invalid_data; +broker::expected OpaqueVal::SerializeType(const TypePtr& t) { + if ( t->InternalType() == TYPE_INTERNAL_ERROR ) + return broker::ec::invalid_data; - if ( t->InternalType() == TYPE_INTERNAL_OTHER ) - { - // Serialize by name. - assert(t->GetName().size()); - return {broker::vector{true, t->GetName()}}; - } + if ( t->InternalType() == TYPE_INTERNAL_OTHER ) { + // Serialize by name. + assert(t->GetName().size()); + return {broker::vector{true, t->GetName()}}; + } - // A base type. - return {broker::vector{false, static_cast(t->Tag())}}; - } + // A base type. + return {broker::vector{false, static_cast(t->Tag())}}; +} -TypePtr OpaqueVal::UnserializeType(const broker::data& data) - { - auto v = broker::get_if(&data); - if ( ! (v && v->size() == 2) ) - return nullptr; +TypePtr OpaqueVal::UnserializeType(const broker::data& data) { + auto v = broker::get_if(&data); + if ( ! (v && v->size() == 2) ) + return nullptr; - auto by_name = broker::get_if(&(*v)[0]); - if ( ! by_name ) - return nullptr; + auto by_name = broker::get_if(&(*v)[0]); + if ( ! by_name ) + return nullptr; - if ( *by_name ) - { - auto name = broker::get_if(&(*v)[1]); - if ( ! name ) - return nullptr; + if ( *by_name ) { + auto name = broker::get_if(&(*v)[1]); + if ( ! name ) + return nullptr; - const auto& id = detail::global_scope()->Find(*name); - if ( ! id ) - return nullptr; + const auto& id = detail::global_scope()->Find(*name); + if ( ! id ) + return nullptr; - if ( ! id->IsType() ) - return nullptr; + if ( ! id->IsType() ) + return nullptr; - return id->GetType(); - } + return id->GetType(); + } - auto tag = broker::get_if(&(*v)[1]); - if ( ! tag ) - return nullptr; + auto tag = broker::get_if(&(*v)[1]); + if ( ! tag ) + return nullptr; - return base_type(static_cast(*tag)); - } + return base_type(static_cast(*tag)); +} -ValPtr OpaqueVal::DoClone(CloneState* state) - { - auto d = OpaqueVal::Serialize(); - if ( ! d ) - return nullptr; +ValPtr OpaqueVal::DoClone(CloneState* state) { + auto d = OpaqueVal::Serialize(); + if ( ! d ) + return nullptr; - auto rval = OpaqueVal::Unserialize(std::move(*d)); - return state->NewClone(this, std::move(rval)); - } + auto rval = OpaqueVal::Unserialize(std::move(*d)); + return state->NewClone(this, std::move(rval)); +} -void OpaqueVal::ValDescribe(ODesc* d) const - { - d->Add(util::fmt("", OpaqueName())); - } +void OpaqueVal::ValDescribe(ODesc* d) const { d->Add(util::fmt("", OpaqueName())); } -void OpaqueVal::ValDescribeReST(ODesc* d) const - { - d->Add(util::fmt("", OpaqueName())); - } +void OpaqueVal::ValDescribeReST(ODesc* d) const { d->Add(util::fmt("", OpaqueName())); } -bool HashVal::IsValid() const - { - return valid; - } +bool HashVal::IsValid() const { return valid; } -bool HashVal::Init() - { - if ( valid ) - return false; +bool HashVal::Init() { + if ( valid ) + return false; - valid = DoInit(); - return valid; - } + valid = DoInit(); + return valid; +} -StringValPtr HashVal::Get() - { - if ( ! valid ) - return val_mgr->EmptyString(); +StringValPtr HashVal::Get() { + if ( ! valid ) + return val_mgr->EmptyString(); - auto result = DoGet(); - valid = false; - return result; - } + auto result = DoGet(); + valid = false; + return result; +} -bool HashVal::Feed(const void* data, size_t size) - { - if ( valid ) - return DoFeed(data, size); +bool HashVal::Feed(const void* data, size_t size) { + if ( valid ) + return DoFeed(data, size); - Error("attempt to update an invalid opaque hash value"); - return false; - } + Error("attempt to update an invalid opaque hash value"); + return false; +} -bool HashVal::DoInit() - { - assert(! "missing implementation of DoInit()"); - return false; - } +bool HashVal::DoInit() { + assert(! "missing implementation of DoInit()"); + return false; +} -bool HashVal::DoFeed(const void*, size_t) - { - assert(! "missing implementation of DoFeed()"); - return false; - } +bool HashVal::DoFeed(const void*, size_t) { + assert(! "missing implementation of DoFeed()"); + return false; +} -StringValPtr HashVal::DoGet() - { - assert(! "missing implementation of DoGet()"); - return val_mgr->EmptyString(); - } +StringValPtr HashVal::DoGet() { + assert(! "missing implementation of DoGet()"); + return val_mgr->EmptyString(); +} -HashVal::HashVal(OpaqueTypePtr t) : OpaqueVal(std::move(t)) - { - valid = false; - } +HashVal::HashVal(OpaqueTypePtr t) : OpaqueVal(std::move(t)) { valid = false; } -MD5Val::MD5Val() : HashVal(md5_type) - { - memset(&ctx, 0, sizeof(ctx)); - } +MD5Val::MD5Val() : HashVal(md5_type) { memset(&ctx, 0, sizeof(ctx)); } -MD5Val::~MD5Val() - { +MD5Val::~MD5Val() { #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - if ( IsValid() ) - EVP_MD_CTX_free(ctx); + if ( IsValid() ) + EVP_MD_CTX_free(ctx); #endif - } +} -void HashVal::digest_one(EVP_MD_CTX* h, const Val* v) - { - if ( v->GetType()->Tag() == TYPE_STRING ) - { - const String* str = v->AsString(); - detail::hash_update(h, str->Bytes(), str->Len()); - } - else - { - ODesc d(DESC_BINARY); - v->Describe(&d); - detail::hash_update(h, (const u_char*)d.Bytes(), d.Len()); - } - } +void HashVal::digest_one(EVP_MD_CTX* h, const Val* v) { + if ( v->GetType()->Tag() == TYPE_STRING ) { + const String* str = v->AsString(); + detail::hash_update(h, str->Bytes(), str->Len()); + } + else { + ODesc d(DESC_BINARY); + v->Describe(&d); + detail::hash_update(h, (const u_char*)d.Bytes(), d.Len()); + } +} -void HashVal::digest_one(EVP_MD_CTX* h, const ValPtr& v) - { - digest_one(h, v.get()); - } +void HashVal::digest_one(EVP_MD_CTX* h, const ValPtr& v) { digest_one(h, v.get()); } -ValPtr MD5Val::DoClone(CloneState* state) - { - auto out = make_intrusive(); +ValPtr MD5Val::DoClone(CloneState* state) { + auto out = make_intrusive(); - if ( IsValid() ) - { - if ( ! out->Init() ) - return nullptr; + if ( IsValid() ) { + if ( ! out->Init() ) + return nullptr; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - EVP_MD_CTX_copy_ex(out->ctx, ctx); + EVP_MD_CTX_copy_ex(out->ctx, ctx); #else - out->ctx = ctx; + out->ctx = ctx; #endif - } + } - return state->NewClone(this, std::move(out)); - } + return state->NewClone(this, std::move(out)); +} -bool MD5Val::DoInit() - { - assert(! IsValid()); +bool MD5Val::DoInit() { + assert(! IsValid()); #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - ctx = detail::hash_init(detail::Hash_MD5); + ctx = detail::hash_init(detail::Hash_MD5); #else - MD5_Init(&ctx); + MD5_Init(&ctx); #endif - return true; - } + return true; +} -bool MD5Val::DoFeed(const void* data, size_t size) - { - if ( ! IsValid() ) - return false; +bool MD5Val::DoFeed(const void* data, size_t size) { + if ( ! IsValid() ) + return false; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - detail::hash_update(ctx, data, size); + detail::hash_update(ctx, data, size); #else - MD5_Update(&ctx, data, size); + MD5_Update(&ctx, data, size); #endif - return true; - } + return true; +} -StringValPtr MD5Val::DoGet() - { - if ( ! IsValid() ) - return val_mgr->EmptyString(); +StringValPtr MD5Val::DoGet() { + if ( ! IsValid() ) + return val_mgr->EmptyString(); - u_char digest[MD5_DIGEST_LENGTH]; + u_char digest[MD5_DIGEST_LENGTH]; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - detail::hash_final(ctx, digest); + detail::hash_final(ctx, digest); #else - MD5_Final(digest, &ctx); + MD5_Final(digest, &ctx); #endif - return make_intrusive(detail::md5_digest_print(digest)); - } + return make_intrusive(detail::md5_digest_print(digest)); +} IMPLEMENT_OPAQUE_VALUE(MD5Val) -broker::expected MD5Val::DoSerialize() const - { - if ( ! IsValid() ) - return {broker::vector{false}}; +broker::expected MD5Val::DoSerialize() const { + if ( ! IsValid() ) + return {broker::vector{false}}; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - MD5_CTX* md = (MD5_CTX*)EVP_MD_CTX_md_data(ctx); - auto data = std::string(reinterpret_cast(md), sizeof(MD5_CTX)); + MD5_CTX* md = (MD5_CTX*)EVP_MD_CTX_md_data(ctx); + auto data = std::string(reinterpret_cast(md), sizeof(MD5_CTX)); #else - auto data = std::string(reinterpret_cast(&ctx), sizeof(ctx)); + auto data = std::string(reinterpret_cast(&ctx), sizeof(ctx)); #endif - broker::vector d = {true, data}; - return {std::move(d)}; - } + broker::vector d = {true, data}; + return {std::move(d)}; +} -bool MD5Val::DoUnserialize(const broker::data& data) - { - auto d = broker::get_if(&data); - if ( ! d ) - return false; +bool MD5Val::DoUnserialize(const broker::data& data) { + auto d = broker::get_if(&data); + if ( ! d ) + return false; - auto valid = broker::get_if(&(*d)[0]); - if ( ! valid ) - return false; + auto valid = broker::get_if(&(*d)[0]); + if ( ! valid ) + return false; - if ( ! *valid ) - { - assert(! IsValid()); // default set by ctor - return true; - } + if ( ! *valid ) { + assert(! IsValid()); // default set by ctor + return true; + } - if ( (*d).size() != 2 ) - return false; + if ( (*d).size() != 2 ) + return false; - auto s = broker::get_if(&(*d)[1]); - if ( ! s ) - return false; + auto s = broker::get_if(&(*d)[1]); + if ( ! s ) + return false; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - if ( sizeof(MD5_CTX) != s->size() ) + if ( sizeof(MD5_CTX) != s->size() ) #else - if ( sizeof(ctx) != s->size() ) + if ( sizeof(ctx) != s->size() ) #endif - return false; + return false; - Init(); + Init(); #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - MD5_CTX* md = (MD5_CTX*)EVP_MD_CTX_md_data(ctx); - memcpy(md, s->data(), s->size()); + MD5_CTX* md = (MD5_CTX*)EVP_MD_CTX_md_data(ctx); + memcpy(md, s->data(), s->size()); #else - memcpy(&ctx, s->data(), s->size()); + memcpy(&ctx, s->data(), s->size()); #endif - return true; - } + return true; +} -SHA1Val::SHA1Val() : HashVal(sha1_type) - { - memset(&ctx, 0, sizeof(ctx)); - } +SHA1Val::SHA1Val() : HashVal(sha1_type) { memset(&ctx, 0, sizeof(ctx)); } -SHA1Val::~SHA1Val() - { +SHA1Val::~SHA1Val() { #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - if ( IsValid() ) - EVP_MD_CTX_free(ctx); + if ( IsValid() ) + EVP_MD_CTX_free(ctx); #endif - } +} -ValPtr SHA1Val::DoClone(CloneState* state) - { - auto out = make_intrusive(); +ValPtr SHA1Val::DoClone(CloneState* state) { + auto out = make_intrusive(); - if ( IsValid() ) - { - if ( ! out->Init() ) - return nullptr; + if ( IsValid() ) { + if ( ! out->Init() ) + return nullptr; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - EVP_MD_CTX_copy_ex(out->ctx, ctx); + EVP_MD_CTX_copy_ex(out->ctx, ctx); #else - out->ctx = ctx; + out->ctx = ctx; #endif - } + } - return state->NewClone(this, std::move(out)); - } + return state->NewClone(this, std::move(out)); +} -bool SHA1Val::DoInit() - { - assert(! IsValid()); +bool SHA1Val::DoInit() { + assert(! IsValid()); #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - ctx = detail::hash_init(detail::Hash_SHA1); + ctx = detail::hash_init(detail::Hash_SHA1); #else - SHA1_Init(&ctx); + SHA1_Init(&ctx); #endif - return true; - } + return true; +} -bool SHA1Val::DoFeed(const void* data, size_t size) - { - if ( ! IsValid() ) - return false; +bool SHA1Val::DoFeed(const void* data, size_t size) { + if ( ! IsValid() ) + return false; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - detail::hash_update(ctx, data, size); + detail::hash_update(ctx, data, size); #else - SHA1_Update(&ctx, data, size); + SHA1_Update(&ctx, data, size); #endif - return true; - } + return true; +} -StringValPtr SHA1Val::DoGet() - { - if ( ! IsValid() ) - return val_mgr->EmptyString(); +StringValPtr SHA1Val::DoGet() { + if ( ! IsValid() ) + return val_mgr->EmptyString(); - u_char digest[SHA_DIGEST_LENGTH]; + u_char digest[SHA_DIGEST_LENGTH]; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - detail::hash_final(ctx, digest); + detail::hash_final(ctx, digest); #else - SHA1_Final(digest, &ctx); + SHA1_Final(digest, &ctx); #endif - return make_intrusive(detail::sha1_digest_print(digest)); - } + return make_intrusive(detail::sha1_digest_print(digest)); +} IMPLEMENT_OPAQUE_VALUE(SHA1Val) -broker::expected SHA1Val::DoSerialize() const - { - if ( ! IsValid() ) - return {broker::vector{false}}; +broker::expected SHA1Val::DoSerialize() const { + if ( ! IsValid() ) + return {broker::vector{false}}; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - SHA_CTX* md = (SHA_CTX*)EVP_MD_CTX_md_data(ctx); - auto data = std::string(reinterpret_cast(md), sizeof(SHA_CTX)); + SHA_CTX* md = (SHA_CTX*)EVP_MD_CTX_md_data(ctx); + auto data = std::string(reinterpret_cast(md), sizeof(SHA_CTX)); #else - auto data = std::string(reinterpret_cast(&ctx), sizeof(ctx)); + auto data = std::string(reinterpret_cast(&ctx), sizeof(ctx)); #endif - broker::vector d = {true, data}; + broker::vector d = {true, data}; - return {std::move(d)}; - } + return {std::move(d)}; +} -bool SHA1Val::DoUnserialize(const broker::data& data) - { - auto d = broker::get_if(&data); - if ( ! d ) - return false; +bool SHA1Val::DoUnserialize(const broker::data& data) { + auto d = broker::get_if(&data); + if ( ! d ) + return false; - auto valid = broker::get_if(&(*d)[0]); - if ( ! valid ) - return false; + auto valid = broker::get_if(&(*d)[0]); + if ( ! valid ) + return false; - if ( ! *valid ) - { - assert(! IsValid()); // default set by ctor - return true; - } + if ( ! *valid ) { + assert(! IsValid()); // default set by ctor + return true; + } - if ( (*d).size() != 2 ) - return false; + if ( (*d).size() != 2 ) + return false; - auto s = broker::get_if(&(*d)[1]); - if ( ! s ) - return false; + auto s = broker::get_if(&(*d)[1]); + if ( ! s ) + return false; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - if ( sizeof(SHA_CTX) != s->size() ) + if ( sizeof(SHA_CTX) != s->size() ) #else - if ( sizeof(ctx) != s->size() ) + if ( sizeof(ctx) != s->size() ) #endif - return false; + return false; - Init(); + Init(); #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - SHA_CTX* md = (SHA_CTX*)EVP_MD_CTX_md_data(ctx); - memcpy(md, s->data(), s->size()); + SHA_CTX* md = (SHA_CTX*)EVP_MD_CTX_md_data(ctx); + memcpy(md, s->data(), s->size()); #else - memcpy(&ctx, s->data(), s->size()); + memcpy(&ctx, s->data(), s->size()); #endif - return true; - } + return true; +} -SHA256Val::SHA256Val() : HashVal(sha256_type) - { - memset(&ctx, 0, sizeof(ctx)); - } +SHA256Val::SHA256Val() : HashVal(sha256_type) { memset(&ctx, 0, sizeof(ctx)); } -SHA256Val::~SHA256Val() - { +SHA256Val::~SHA256Val() { #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - if ( IsValid() ) - EVP_MD_CTX_free(ctx); + if ( IsValid() ) + EVP_MD_CTX_free(ctx); #endif - } +} -ValPtr SHA256Val::DoClone(CloneState* state) - { - auto out = make_intrusive(); +ValPtr SHA256Val::DoClone(CloneState* state) { + auto out = make_intrusive(); - if ( IsValid() ) - { - if ( ! out->Init() ) - return nullptr; + if ( IsValid() ) { + if ( ! out->Init() ) + return nullptr; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - EVP_MD_CTX_copy_ex(out->ctx, ctx); + EVP_MD_CTX_copy_ex(out->ctx, ctx); #else - out->ctx = ctx; + out->ctx = ctx; #endif - } + } - return state->NewClone(this, std::move(out)); - } + return state->NewClone(this, std::move(out)); +} -bool SHA256Val::DoInit() - { - assert(! IsValid()); +bool SHA256Val::DoInit() { + assert(! IsValid()); #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - ctx = detail::hash_init(detail::Hash_SHA256); + ctx = detail::hash_init(detail::Hash_SHA256); #else - SHA256_Init(&ctx); + SHA256_Init(&ctx); #endif - return true; - } + return true; +} -bool SHA256Val::DoFeed(const void* data, size_t size) - { - if ( ! IsValid() ) - return false; +bool SHA256Val::DoFeed(const void* data, size_t size) { + if ( ! IsValid() ) + return false; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - detail::hash_update(ctx, data, size); + detail::hash_update(ctx, data, size); #else - SHA256_Update(&ctx, data, size); + SHA256_Update(&ctx, data, size); #endif - return true; - } + return true; +} -StringValPtr SHA256Val::DoGet() - { - if ( ! IsValid() ) - return val_mgr->EmptyString(); +StringValPtr SHA256Val::DoGet() { + if ( ! IsValid() ) + return val_mgr->EmptyString(); - u_char digest[SHA256_DIGEST_LENGTH]; + u_char digest[SHA256_DIGEST_LENGTH]; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - detail::hash_final(ctx, digest); + detail::hash_final(ctx, digest); #else - SHA256_Final(digest, &ctx); + SHA256_Final(digest, &ctx); #endif - return make_intrusive(detail::sha256_digest_print(digest)); - } + return make_intrusive(detail::sha256_digest_print(digest)); +} IMPLEMENT_OPAQUE_VALUE(SHA256Val) -broker::expected SHA256Val::DoSerialize() const - { - if ( ! IsValid() ) - return {broker::vector{false}}; +broker::expected SHA256Val::DoSerialize() const { + if ( ! IsValid() ) + return {broker::vector{false}}; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - SHA256_CTX* md = (SHA256_CTX*)EVP_MD_CTX_md_data(ctx); - auto data = std::string(reinterpret_cast(md), sizeof(SHA256_CTX)); + SHA256_CTX* md = (SHA256_CTX*)EVP_MD_CTX_md_data(ctx); + auto data = std::string(reinterpret_cast(md), sizeof(SHA256_CTX)); #else - auto data = std::string(reinterpret_cast(&ctx), sizeof(ctx)); + auto data = std::string(reinterpret_cast(&ctx), sizeof(ctx)); #endif - broker::vector d = {true, data}; + broker::vector d = {true, data}; - return {std::move(d)}; - } + return {std::move(d)}; +} -bool SHA256Val::DoUnserialize(const broker::data& data) - { - auto d = broker::get_if(&data); - if ( ! d ) - return false; +bool SHA256Val::DoUnserialize(const broker::data& data) { + auto d = broker::get_if(&data); + if ( ! d ) + return false; - auto valid = broker::get_if(&(*d)[0]); - if ( ! valid ) - return false; + auto valid = broker::get_if(&(*d)[0]); + if ( ! valid ) + return false; - if ( ! *valid ) - { - assert(! IsValid()); // default set by ctor - return true; - } + if ( ! *valid ) { + assert(! IsValid()); // default set by ctor + return true; + } - if ( (*d).size() != 2 ) - return false; + if ( (*d).size() != 2 ) + return false; - auto s = broker::get_if(&(*d)[1]); - if ( ! s ) - return false; + auto s = broker::get_if(&(*d)[1]); + if ( ! s ) + return false; #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - if ( sizeof(SHA256_CTX) != s->size() ) + if ( sizeof(SHA256_CTX) != s->size() ) #else - if ( sizeof(ctx) != s->size() ) + if ( sizeof(ctx) != s->size() ) #endif - return false; + return false; - Init(); + Init(); #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - SHA256_CTX* md = (SHA256_CTX*)EVP_MD_CTX_md_data(ctx); - memcpy(md, s->data(), s->size()); + SHA256_CTX* md = (SHA256_CTX*)EVP_MD_CTX_md_data(ctx); + memcpy(md, s->data(), s->size()); #else - memcpy(&ctx, s->data(), s->size()); + memcpy(&ctx, s->data(), s->size()); #endif - return true; - } + return true; +} -EntropyVal::EntropyVal() : OpaqueVal(entropy_type) { } +EntropyVal::EntropyVal() : OpaqueVal(entropy_type) {} -bool EntropyVal::Feed(const void* data, size_t size) - { - state.add(data, size); - return true; - } +bool EntropyVal::Feed(const void* data, size_t size) { + state.add(data, size); + return true; +} -bool EntropyVal::Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, - double* r_scc) - { - state.end(r_ent, r_chisq, r_mean, r_montepicalc, r_scc); - return true; - } +bool EntropyVal::Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, double* r_scc) { + state.end(r_ent, r_chisq, r_mean, r_montepicalc, r_scc); + return true; +} IMPLEMENT_OPAQUE_VALUE(EntropyVal) -broker::expected EntropyVal::DoSerialize() const - { - broker::vector d = { - static_cast(state.totalc), static_cast(state.mp), - static_cast(state.sccfirst), static_cast(state.inmont), - static_cast(state.mcount), static_cast(state.cexp), - static_cast(state.montex), static_cast(state.montey), - static_cast(state.montepi), static_cast(state.sccu0), - static_cast(state.scclast), static_cast(state.scct1), - static_cast(state.scct2), static_cast(state.scct3), - }; +broker::expected EntropyVal::DoSerialize() const { + broker::vector d = { + static_cast(state.totalc), static_cast(state.mp), + static_cast(state.sccfirst), static_cast(state.inmont), + static_cast(state.mcount), static_cast(state.cexp), + static_cast(state.montex), static_cast(state.montey), + static_cast(state.montepi), static_cast(state.sccu0), + static_cast(state.scclast), static_cast(state.scct1), + static_cast(state.scct2), static_cast(state.scct3), + }; - d.reserve(256 + 3 + RT_MONTEN + 11); + d.reserve(256 + 3 + RT_MONTEN + 11); - for ( int i = 0; i < 256; ++i ) - d.emplace_back(static_cast(state.ccount[i])); + for ( int i = 0; i < 256; ++i ) + d.emplace_back(static_cast(state.ccount[i])); - for ( int i = 0; i < RT_MONTEN; ++i ) - d.emplace_back(static_cast(state.monte[i])); + for ( int i = 0; i < RT_MONTEN; ++i ) + d.emplace_back(static_cast(state.monte[i])); - return {std::move(d)}; - } + return {std::move(d)}; +} -bool EntropyVal::DoUnserialize(const broker::data& data) - { - auto d = broker::get_if(&data); - if ( ! d ) - return false; +bool EntropyVal::DoUnserialize(const broker::data& data) { + auto d = broker::get_if(&data); + if ( ! d ) + return false; - if ( ! get_vector_idx(*d, 0, &state.totalc) ) - return false; - if ( ! get_vector_idx(*d, 1, &state.mp) ) - return false; - if ( ! get_vector_idx(*d, 2, &state.sccfirst) ) - return false; - if ( ! get_vector_idx(*d, 3, &state.inmont) ) - return false; - if ( ! get_vector_idx(*d, 4, &state.mcount) ) - return false; - if ( ! get_vector_idx(*d, 5, &state.cexp) ) - return false; - if ( ! get_vector_idx(*d, 6, &state.montex) ) - return false; - if ( ! get_vector_idx(*d, 7, &state.montey) ) - return false; - if ( ! get_vector_idx(*d, 8, &state.montepi) ) - return false; - if ( ! get_vector_idx(*d, 9, &state.sccu0) ) - return false; - if ( ! get_vector_idx(*d, 10, &state.scclast) ) - return false; - if ( ! get_vector_idx(*d, 11, &state.scct1) ) - return false; - if ( ! get_vector_idx(*d, 12, &state.scct2) ) - return false; - if ( ! get_vector_idx(*d, 13, &state.scct3) ) - return false; + if ( ! get_vector_idx(*d, 0, &state.totalc) ) + return false; + if ( ! get_vector_idx(*d, 1, &state.mp) ) + return false; + if ( ! get_vector_idx(*d, 2, &state.sccfirst) ) + return false; + if ( ! get_vector_idx(*d, 3, &state.inmont) ) + return false; + if ( ! get_vector_idx(*d, 4, &state.mcount) ) + return false; + if ( ! get_vector_idx(*d, 5, &state.cexp) ) + return false; + if ( ! get_vector_idx(*d, 6, &state.montex) ) + return false; + if ( ! get_vector_idx(*d, 7, &state.montey) ) + return false; + if ( ! get_vector_idx(*d, 8, &state.montepi) ) + return false; + if ( ! get_vector_idx(*d, 9, &state.sccu0) ) + return false; + if ( ! get_vector_idx(*d, 10, &state.scclast) ) + return false; + if ( ! get_vector_idx(*d, 11, &state.scct1) ) + return false; + if ( ! get_vector_idx(*d, 12, &state.scct2) ) + return false; + if ( ! get_vector_idx(*d, 13, &state.scct3) ) + return false; - for ( int i = 0; i < 256; ++i ) - { - if ( ! get_vector_idx(*d, 14 + i, &state.ccount[i]) ) - return false; - } + for ( int i = 0; i < 256; ++i ) { + if ( ! get_vector_idx(*d, 14 + i, &state.ccount[i]) ) + return false; + } - for ( int i = 0; i < RT_MONTEN; ++i ) - { - if ( ! get_vector_idx(*d, 14 + 256 + i, &state.monte[i]) ) - return false; - } + for ( int i = 0; i < RT_MONTEN; ++i ) { + if ( ! get_vector_idx(*d, 14 + 256 + i, &state.monte[i]) ) + return false; + } - return true; - } + return true; +} -BloomFilterVal::BloomFilterVal() : OpaqueVal(bloomfilter_type) - { - hash = nullptr; - bloom_filter = nullptr; - } +BloomFilterVal::BloomFilterVal() : OpaqueVal(bloomfilter_type) { + hash = nullptr; + bloom_filter = nullptr; +} -BloomFilterVal::BloomFilterVal(probabilistic::BloomFilter* bf) : OpaqueVal(bloomfilter_type) - { - hash = nullptr; - bloom_filter = bf; - } +BloomFilterVal::BloomFilterVal(probabilistic::BloomFilter* bf) : OpaqueVal(bloomfilter_type) { + hash = nullptr; + bloom_filter = bf; +} -ValPtr BloomFilterVal::DoClone(CloneState* state) - { - if ( bloom_filter ) - { - auto bf = make_intrusive(bloom_filter->Clone()); - assert(type); - bf->Typify(type); - return state->NewClone(this, std::move(bf)); - } +ValPtr BloomFilterVal::DoClone(CloneState* state) { + if ( bloom_filter ) { + auto bf = make_intrusive(bloom_filter->Clone()); + assert(type); + bf->Typify(type); + return state->NewClone(this, std::move(bf)); + } - return state->NewClone(this, make_intrusive()); - } + return state->NewClone(this, make_intrusive()); +} -bool BloomFilterVal::Typify(TypePtr arg_type) - { - if ( type ) - return false; +bool BloomFilterVal::Typify(TypePtr arg_type) { + if ( type ) + return false; - type = std::move(arg_type); + type = std::move(arg_type); - auto tl = make_intrusive(type); - tl->Append(type); - hash = new detail::CompositeHash(std::move(tl)); + auto tl = make_intrusive(type); + tl->Append(type); + hash = new detail::CompositeHash(std::move(tl)); - return true; - } + return true; +} -void BloomFilterVal::Add(const Val* val) - { - auto key = hash->MakeHashKey(*val, true); - bloom_filter->Add(key.get()); - } +void BloomFilterVal::Add(const Val* val) { + auto key = hash->MakeHashKey(*val, true); + bloom_filter->Add(key.get()); +} -bool BloomFilterVal::Decrement(const Val* val) - { - auto key = hash->MakeHashKey(*val, true); - return bloom_filter->Decrement(key.get()); - } +bool BloomFilterVal::Decrement(const Val* val) { + auto key = hash->MakeHashKey(*val, true); + return bloom_filter->Decrement(key.get()); +} -size_t BloomFilterVal::Count(const Val* val) const - { - auto key = hash->MakeHashKey(*val, true); - size_t cnt = bloom_filter->Count(key.get()); - return cnt; - } +size_t BloomFilterVal::Count(const Val* val) const { + auto key = hash->MakeHashKey(*val, true); + size_t cnt = bloom_filter->Count(key.get()); + return cnt; +} -void BloomFilterVal::Clear() - { - bloom_filter->Clear(); - } +void BloomFilterVal::Clear() { bloom_filter->Clear(); } -bool BloomFilterVal::Empty() const - { - return bloom_filter->Empty(); - } +bool BloomFilterVal::Empty() const { return bloom_filter->Empty(); } -std::string BloomFilterVal::InternalState() const - { - return bloom_filter->InternalState(); - } +std::string BloomFilterVal::InternalState() const { return bloom_filter->InternalState(); } -BloomFilterValPtr BloomFilterVal::Merge(const BloomFilterVal* x, const BloomFilterVal* y) - { - if ( x->Type() && // any one 0 is ok here - y->Type() && ! same_type(x->Type(), y->Type()) ) - { - reporter->Error("cannot merge Bloom filters with different types"); - return nullptr; - } +BloomFilterValPtr BloomFilterVal::Merge(const BloomFilterVal* x, const BloomFilterVal* y) { + if ( x->Type() && // any one 0 is ok here + y->Type() && ! same_type(x->Type(), y->Type()) ) { + reporter->Error("cannot merge Bloom filters with different types"); + return nullptr; + } - auto final_type = x->Type() ? x->Type() : y->Type(); + auto final_type = x->Type() ? x->Type() : y->Type(); - if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) - { - reporter->Error("cannot merge different Bloom filter types"); - return nullptr; - } + if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) { + reporter->Error("cannot merge different Bloom filter types"); + return nullptr; + } - probabilistic::BloomFilter* copy = x->bloom_filter->Clone(); + probabilistic::BloomFilter* copy = x->bloom_filter->Clone(); - if ( ! copy->Merge(y->bloom_filter) ) - { - delete copy; - reporter->Error("failed to merge Bloom filter"); - return nullptr; - } + if ( ! copy->Merge(y->bloom_filter) ) { + delete copy; + reporter->Error("failed to merge Bloom filter"); + return nullptr; + } - auto merged = make_intrusive(copy); + auto merged = make_intrusive(copy); - if ( final_type && ! merged->Typify(final_type) ) - { - reporter->Error("failed to set type on merged Bloom filter"); - return nullptr; - } + if ( final_type && ! merged->Typify(final_type) ) { + reporter->Error("failed to set type on merged Bloom filter"); + return nullptr; + } - return merged; - } + return merged; +} -BloomFilterValPtr BloomFilterVal::Intersect(const BloomFilterVal* x, const BloomFilterVal* y) - { - if ( x->Type() && // any one 0 is ok here - y->Type() && ! same_type(x->Type(), y->Type()) ) - { - reporter->Error("cannot merge Bloom filters with different types"); - return nullptr; - } +BloomFilterValPtr BloomFilterVal::Intersect(const BloomFilterVal* x, const BloomFilterVal* y) { + if ( x->Type() && // any one 0 is ok here + y->Type() && ! same_type(x->Type(), y->Type()) ) { + reporter->Error("cannot merge Bloom filters with different types"); + return nullptr; + } - if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) - { - reporter->Error("cannot intersect different Bloom filter types"); - return nullptr; - } + if ( typeid(*x->bloom_filter) != typeid(*y->bloom_filter) ) { + reporter->Error("cannot intersect different Bloom filter types"); + return nullptr; + } - auto intersected_bf = x->bloom_filter->Intersect(y->bloom_filter); + auto intersected_bf = x->bloom_filter->Intersect(y->bloom_filter); - if ( ! intersected_bf ) - { - reporter->Error("failed to intersect Bloom filter"); - return nullptr; - } + if ( ! intersected_bf ) { + reporter->Error("failed to intersect Bloom filter"); + return nullptr; + } - auto final_type = x->Type() ? x->Type() : y->Type(); + auto final_type = x->Type() ? x->Type() : y->Type(); - auto intersected = make_intrusive(intersected_bf); + auto intersected = make_intrusive(intersected_bf); - if ( final_type && ! intersected->Typify(final_type) ) - { - reporter->Error("Failed to set type on intersected bloom filter"); - return nullptr; - } + if ( final_type && ! intersected->Typify(final_type) ) { + reporter->Error("Failed to set type on intersected bloom filter"); + return nullptr; + } - return intersected; - } + return intersected; +} -BloomFilterVal::~BloomFilterVal() - { - delete hash; - delete bloom_filter; - } +BloomFilterVal::~BloomFilterVal() { + delete hash; + delete bloom_filter; +} IMPLEMENT_OPAQUE_VALUE(BloomFilterVal) -broker::expected BloomFilterVal::DoSerialize() const - { - broker::vector d; +broker::expected BloomFilterVal::DoSerialize() const { + broker::vector d; - if ( type ) - { - auto t = SerializeType(type); - if ( ! t ) - return broker::ec::invalid_data; + if ( type ) { + auto t = SerializeType(type); + if ( ! t ) + return broker::ec::invalid_data; - d.emplace_back(std::move(*t)); - } - else - d.emplace_back(broker::none()); + d.emplace_back(std::move(*t)); + } + else + d.emplace_back(broker::none()); - auto bf = bloom_filter->Serialize(); - if ( ! bf ) - return broker::ec::invalid_data; // Cannot serialize; + auto bf = bloom_filter->Serialize(); + if ( ! bf ) + return broker::ec::invalid_data; // Cannot serialize; - d.emplace_back(*bf); - return {std::move(d)}; - } + d.emplace_back(*bf); + return {std::move(d)}; +} -bool BloomFilterVal::DoUnserialize(const broker::data& data) - { - auto v = broker::get_if(&data); +bool BloomFilterVal::DoUnserialize(const broker::data& data) { + auto v = broker::get_if(&data); - if ( ! (v && v->size() == 2) ) - return false; + if ( ! (v && v->size() == 2) ) + return false; - auto no_type = broker::get_if(&(*v)[0]); - if ( ! no_type ) - { - auto t = UnserializeType((*v)[0]); + auto no_type = broker::get_if(&(*v)[0]); + if ( ! no_type ) { + auto t = UnserializeType((*v)[0]); - if ( ! (t && Typify(std::move(t))) ) - return false; - } + if ( ! (t && Typify(std::move(t))) ) + return false; + } - auto bf = probabilistic::BloomFilter::Unserialize((*v)[1]); - if ( ! bf ) - return false; + auto bf = probabilistic::BloomFilter::Unserialize((*v)[1]); + if ( ! bf ) + return false; - bloom_filter = bf.release(); - return true; - } + bloom_filter = bf.release(); + return true; +} -CardinalityVal::CardinalityVal() : OpaqueVal(cardinality_type) - { - c = nullptr; - hash = nullptr; - } +CardinalityVal::CardinalityVal() : OpaqueVal(cardinality_type) { + c = nullptr; + hash = nullptr; +} -CardinalityVal::CardinalityVal(probabilistic::detail::CardinalityCounter* arg_c) - : OpaqueVal(cardinality_type) - { - c = arg_c; - hash = nullptr; - } +CardinalityVal::CardinalityVal(probabilistic::detail::CardinalityCounter* arg_c) : OpaqueVal(cardinality_type) { + c = arg_c; + hash = nullptr; +} -CardinalityVal::~CardinalityVal() - { - delete c; - delete hash; - } +CardinalityVal::~CardinalityVal() { + delete c; + delete hash; +} -ValPtr CardinalityVal::DoClone(CloneState* state) - { - return state->NewClone( - this, make_intrusive(new probabilistic::detail::CardinalityCounter(*c))); - } +ValPtr CardinalityVal::DoClone(CloneState* state) { + return state->NewClone(this, make_intrusive(new probabilistic::detail::CardinalityCounter(*c))); +} -bool CardinalityVal::Typify(TypePtr arg_type) - { - if ( type ) - return false; +bool CardinalityVal::Typify(TypePtr arg_type) { + if ( type ) + return false; - type = std::move(arg_type); + type = std::move(arg_type); - auto tl = make_intrusive(type); - tl->Append(type); - hash = new detail::CompositeHash(std::move(tl)); + auto tl = make_intrusive(type); + tl->Append(type); + hash = new detail::CompositeHash(std::move(tl)); - return true; - } + return true; +} -void CardinalityVal::Add(const Val* val) - { - auto key = hash->MakeHashKey(*val, true); - c->AddElement(key->Hash()); - } +void CardinalityVal::Add(const Val* val) { + auto key = hash->MakeHashKey(*val, true); + c->AddElement(key->Hash()); +} IMPLEMENT_OPAQUE_VALUE(CardinalityVal) -broker::expected CardinalityVal::DoSerialize() const - { - broker::vector d; +broker::expected CardinalityVal::DoSerialize() const { + broker::vector d; - if ( type ) - { - auto t = SerializeType(type); - if ( ! t ) - return broker::ec::invalid_data; + if ( type ) { + auto t = SerializeType(type); + if ( ! t ) + return broker::ec::invalid_data; - d.emplace_back(std::move(*t)); - } - else - d.emplace_back(broker::none()); + d.emplace_back(std::move(*t)); + } + else + d.emplace_back(broker::none()); - auto cs = c->Serialize(); - if ( ! cs ) - return broker::ec::invalid_data; + auto cs = c->Serialize(); + if ( ! cs ) + return broker::ec::invalid_data; - d.emplace_back(*cs); - return {std::move(d)}; - } + d.emplace_back(*cs); + return {std::move(d)}; +} -bool CardinalityVal::DoUnserialize(const broker::data& data) - { - auto v = broker::get_if(&data); +bool CardinalityVal::DoUnserialize(const broker::data& data) { + auto v = broker::get_if(&data); - if ( ! (v && v->size() == 2) ) - return false; + if ( ! (v && v->size() == 2) ) + return false; - auto no_type = broker::get_if(&(*v)[0]); - if ( ! no_type ) - { - auto t = UnserializeType((*v)[0]); + auto no_type = broker::get_if(&(*v)[0]); + if ( ! no_type ) { + auto t = UnserializeType((*v)[0]); - if ( ! (t && Typify(std::move(t))) ) - return false; - } + if ( ! (t && Typify(std::move(t))) ) + return false; + } - auto cu = probabilistic::detail::CardinalityCounter::Unserialize((*v)[1]); - if ( ! cu ) - return false; + auto cu = probabilistic::detail::CardinalityCounter::Unserialize((*v)[1]); + if ( ! cu ) + return false; - c = cu.release(); - return true; - } + c = cu.release(); + return true; +} -ParaglobVal::ParaglobVal(std::unique_ptr p) : OpaqueVal(paraglob_type) - { - this->internal_paraglob = std::move(p); - } +ParaglobVal::ParaglobVal(std::unique_ptr p) : OpaqueVal(paraglob_type) { + this->internal_paraglob = std::move(p); +} -VectorValPtr ParaglobVal::Get(StringVal*& pattern) - { - auto rval = make_intrusive(id::string_vec); - std::string string_pattern(reinterpret_cast(pattern->Bytes()), pattern->Len()); +VectorValPtr ParaglobVal::Get(StringVal*& pattern) { + auto rval = make_intrusive(id::string_vec); + std::string string_pattern(reinterpret_cast(pattern->Bytes()), pattern->Len()); - std::vector matches = this->internal_paraglob->get(string_pattern); - for ( size_t i = 0; i < matches.size(); i++ ) - rval->Assign(i, make_intrusive(matches.at(i))); + std::vector matches = this->internal_paraglob->get(string_pattern); + for ( size_t i = 0; i < matches.size(); i++ ) + rval->Assign(i, make_intrusive(matches.at(i))); - return rval; - } + return rval; +} -bool ParaglobVal::operator==(const ParaglobVal& other) const - { - return *(this->internal_paraglob) == *(other.internal_paraglob); - } +bool ParaglobVal::operator==(const ParaglobVal& other) const { + return *(this->internal_paraglob) == *(other.internal_paraglob); +} IMPLEMENT_OPAQUE_VALUE(ParaglobVal) -broker::expected ParaglobVal::DoSerialize() const - { - broker::vector d; - std::unique_ptr> iv = this->internal_paraglob->serialize(); - for ( uint8_t a : *(iv.get()) ) - d.emplace_back(static_cast(a)); - return {std::move(d)}; - } +broker::expected ParaglobVal::DoSerialize() const { + broker::vector d; + std::unique_ptr> iv = this->internal_paraglob->serialize(); + for ( uint8_t a : *(iv.get()) ) + d.emplace_back(static_cast(a)); + return {std::move(d)}; +} -bool ParaglobVal::DoUnserialize(const broker::data& data) - { - auto d = broker::get_if(&data); - if ( ! d ) - return false; +bool ParaglobVal::DoUnserialize(const broker::data& data) { + auto d = broker::get_if(&data); + if ( ! d ) + return false; - std::unique_ptr> iv(new std::vector); - iv->resize(d->size()); + std::unique_ptr> iv(new std::vector); + iv->resize(d->size()); - for ( std::vector::size_type i = 0; i < d->size(); ++i ) - { - if ( ! get_vector_idx(*d, i, iv.get()->data() + i) ) - return false; - } + for ( std::vector::size_type i = 0; i < d->size(); ++i ) { + if ( ! get_vector_idx(*d, i, iv.get()->data() + i) ) + return false; + } - try - { - this->internal_paraglob = std::make_unique(std::move(iv)); - } - catch ( const paraglob::underflow_error& e ) - { - reporter->Error("Paraglob underflow error -> %s", e.what()); - return false; - } - catch ( const paraglob::overflow_error& e ) - { - reporter->Error("Paraglob overflow error -> %s", e.what()); - return false; - } + try { + this->internal_paraglob = std::make_unique(std::move(iv)); + } catch ( const paraglob::underflow_error& e ) { + reporter->Error("Paraglob underflow error -> %s", e.what()); + return false; + } catch ( const paraglob::overflow_error& e ) { + reporter->Error("Paraglob overflow error -> %s", e.what()); + return false; + } - return true; - } + return true; +} -ValPtr ParaglobVal::DoClone(CloneState* state) - { - try - { - return make_intrusive( - std::make_unique(this->internal_paraglob->serialize())); - } - catch ( const paraglob::underflow_error& e ) - { - reporter->Error("Paraglob underflow error while cloning -> %s", e.what()); - return nullptr; - } - catch ( const paraglob::overflow_error& e ) - { - reporter->Error("Paraglob overflow error while cloning -> %s", e.what()); - return nullptr; - } - } +ValPtr ParaglobVal::DoClone(CloneState* state) { + try { + return make_intrusive(std::make_unique(this->internal_paraglob->serialize())); + } catch ( const paraglob::underflow_error& e ) { + reporter->Error("Paraglob underflow error while cloning -> %s", e.what()); + return nullptr; + } catch ( const paraglob::overflow_error& e ) { + reporter->Error("Paraglob overflow error while cloning -> %s", e.what()); + return nullptr; + } +} -broker::expected TelemetryVal::DoSerialize() const - { - return broker::make_error(broker::ec::invalid_data, "cannot serialize metric handles"); - } +broker::expected TelemetryVal::DoSerialize() const { + return broker::make_error(broker::ec::invalid_data, "cannot serialize metric handles"); +} -bool TelemetryVal::DoUnserialize(const broker::data&) - { - return false; - } +bool TelemetryVal::DoUnserialize(const broker::data&) { return false; } -TelemetryVal::TelemetryVal(telemetry::IntCounter) : OpaqueVal(int_counter_metric_type) { } +TelemetryVal::TelemetryVal(telemetry::IntCounter) : OpaqueVal(int_counter_metric_type) {} -TelemetryVal::TelemetryVal(telemetry::IntCounterFamily) : OpaqueVal(int_counter_metric_family_type) - { - } +TelemetryVal::TelemetryVal(telemetry::IntCounterFamily) : OpaqueVal(int_counter_metric_family_type) {} -TelemetryVal::TelemetryVal(telemetry::DblCounter) : OpaqueVal(dbl_counter_metric_type) { } +TelemetryVal::TelemetryVal(telemetry::DblCounter) : OpaqueVal(dbl_counter_metric_type) {} -TelemetryVal::TelemetryVal(telemetry::DblCounterFamily) : OpaqueVal(dbl_counter_metric_family_type) - { - } +TelemetryVal::TelemetryVal(telemetry::DblCounterFamily) : OpaqueVal(dbl_counter_metric_family_type) {} -TelemetryVal::TelemetryVal(telemetry::IntGauge) : OpaqueVal(int_gauge_metric_type) { } +TelemetryVal::TelemetryVal(telemetry::IntGauge) : OpaqueVal(int_gauge_metric_type) {} -TelemetryVal::TelemetryVal(telemetry::IntGaugeFamily) : OpaqueVal(int_gauge_metric_family_type) { } +TelemetryVal::TelemetryVal(telemetry::IntGaugeFamily) : OpaqueVal(int_gauge_metric_family_type) {} -TelemetryVal::TelemetryVal(telemetry::DblGauge) : OpaqueVal(dbl_gauge_metric_type) { } +TelemetryVal::TelemetryVal(telemetry::DblGauge) : OpaqueVal(dbl_gauge_metric_type) {} -TelemetryVal::TelemetryVal(telemetry::DblGaugeFamily) : OpaqueVal(dbl_gauge_metric_family_type) { } +TelemetryVal::TelemetryVal(telemetry::DblGaugeFamily) : OpaqueVal(dbl_gauge_metric_family_type) {} -TelemetryVal::TelemetryVal(telemetry::IntHistogram) : OpaqueVal(int_histogram_metric_type) { } +TelemetryVal::TelemetryVal(telemetry::IntHistogram) : OpaqueVal(int_histogram_metric_type) {} -TelemetryVal::TelemetryVal(telemetry::IntHistogramFamily) - : OpaqueVal(int_histogram_metric_family_type) - { - } +TelemetryVal::TelemetryVal(telemetry::IntHistogramFamily) : OpaqueVal(int_histogram_metric_family_type) {} -TelemetryVal::TelemetryVal(telemetry::DblHistogram) : OpaqueVal(dbl_histogram_metric_type) { } +TelemetryVal::TelemetryVal(telemetry::DblHistogram) : OpaqueVal(dbl_histogram_metric_type) {} -TelemetryVal::TelemetryVal(telemetry::DblHistogramFamily) - : OpaqueVal(dbl_histogram_metric_family_type) - { - } +TelemetryVal::TelemetryVal(telemetry::DblHistogramFamily) : OpaqueVal(dbl_histogram_metric_family_type) {} - } +} // namespace zeek diff --git a/src/OpaqueVal.h b/src/OpaqueVal.h index 35ed81e54e..843bed1757 100644 --- a/src/OpaqueVal.h +++ b/src/OpaqueVal.h @@ -21,22 +21,18 @@ #include "zeek/telemetry/Gauge.h" #include "zeek/telemetry/Histogram.h" -namespace broker - { +namespace broker { class data; - } +} -namespace zeek - { +namespace zeek { -namespace probabilistic - { +namespace probabilistic { class BloomFilter; - } -namespace probabilistic::detail - { +} +namespace probabilistic::detail { class CardinalityCounter; - } +} class OpaqueVal; using OpaqueValPtr = IntrusivePtr; @@ -48,60 +44,59 @@ using BloomFilterValPtr = IntrusivePtr; * Singleton that registers all available all available types of opaque * values. This facilitates their serialization into Broker values. */ -class OpaqueMgr - { +class OpaqueMgr { public: - using Factory = OpaqueValPtr(); + using Factory = OpaqueValPtr(); - /** - * Return's a unique ID for the type of an opaque value. - * @param v opaque value to return type for; its class must have been - * registered with the manager, otherwise this method will abort - * execution. - * - * @return type ID, which can used with *Instantiate()* to create a - * new instance of the same type. - */ - const std::string& TypeID(const OpaqueVal* v) const; + /** + * Return's a unique ID for the type of an opaque value. + * @param v opaque value to return type for; its class must have been + * registered with the manager, otherwise this method will abort + * execution. + * + * @return type ID, which can used with *Instantiate()* to create a + * new instance of the same type. + */ + const std::string& TypeID(const OpaqueVal* v) const; - /** - * Instantiates a new opaque value of a specific opaque type. - * - * @param id unique type ID for the class to instantiate; this will - * normally have been returned earlier by *TypeID()*. - * - * @return A freshly instantiated value of the OpaqueVal-derived - * classes that *id* specifies, with reference count at +1. If *id* - * is unknown, this will return null. - * - */ - OpaqueValPtr Instantiate(const std::string& id) const; + /** + * Instantiates a new opaque value of a specific opaque type. + * + * @param id unique type ID for the class to instantiate; this will + * normally have been returned earlier by *TypeID()*. + * + * @return A freshly instantiated value of the OpaqueVal-derived + * classes that *id* specifies, with reference count at +1. If *id* + * is unknown, this will return null. + * + */ + OpaqueValPtr Instantiate(const std::string& id) const; - /** Returns the global manager singleton object. */ - static OpaqueMgr* mgr(); + /** Returns the global manager singleton object. */ + static OpaqueMgr* mgr(); - /** - * Internal helper class to register an OpaqueVal-derived classes - * with the manager. - */ - template class Register - { - public: - Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); } - }; + /** + * Internal helper class to register an OpaqueVal-derived classes + * with the manager. + */ + template + class Register { + public: + Register(const char* id) { OpaqueMgr::mgr()->_types.emplace(id, &T::OpaqueInstantiate); } + }; private: - std::unordered_map _types; - }; + std::unordered_map _types; +}; /** Macro to insert into an OpaqueVal-derived class's declaration. */ -#define DECLARE_OPAQUE_VALUE(T) \ - friend class zeek::OpaqueMgr::Register; \ - friend zeek::IntrusivePtr zeek::make_intrusive(); \ - broker::expected DoSerialize() const override; \ - bool DoUnserialize(const broker::data& data) override; \ - const char* OpaqueName() const override { return #T; } \ - static zeek::OpaqueValPtr OpaqueInstantiate() { return zeek::make_intrusive(); } +#define DECLARE_OPAQUE_VALUE(T) \ + friend class zeek::OpaqueMgr::Register; \ + friend zeek::IntrusivePtr zeek::make_intrusive(); \ + broker::expected DoSerialize() const override; \ + bool DoUnserialize(const broker::data& data) override; \ + const char* OpaqueName() const override { return #T; } \ + static zeek::OpaqueValPtr OpaqueInstantiate() { return zeek::make_intrusive(); } #define __OPAQUE_MERGE(a, b) a##b #define __OPAQUE_ID(x) __OPAQUE_MERGE(_opaque, x) @@ -114,348 +109,335 @@ private: * completely internally, with no further script-level operators provided * (other than bif functions). See OpaqueVal.h for derived classes. */ -class OpaqueVal : public Val - { +class OpaqueVal : public Val { public: - explicit OpaqueVal(OpaqueTypePtr t); - ~OpaqueVal() override = default; + explicit OpaqueVal(OpaqueTypePtr t); + ~OpaqueVal() override = default; - /** - * Serializes the value into a Broker representation. - * - * @return the broker representation, or an error if serialization - * isn't supported or failed. - */ - broker::expected Serialize() const; + /** + * Serializes the value into a Broker representation. + * + * @return the broker representation, or an error if serialization + * isn't supported or failed. + */ + broker::expected Serialize() const; - /** - * Reinstantiates a value from its serialized Broker representation. - * - * @param data Broker representation as returned by *Serialize()*. - * @return unserialized instances with reference count at +1 - */ - static OpaqueValPtr Unserialize(const broker::data& data); + /** + * Reinstantiates a value from its serialized Broker representation. + * + * @param data Broker representation as returned by *Serialize()*. + * @return unserialized instances with reference count at +1 + */ + static OpaqueValPtr Unserialize(const broker::data& data); protected: - friend class Val; - friend class OpaqueMgr; + friend class Val; + friend class OpaqueMgr; - /** - * Must be overridden to provide a serialized version of the derived - * class' state. - * - * @return the serialized data or an error if serialization - * isn't supported or failed. - */ - virtual broker::expected DoSerialize() const = 0; + /** + * Must be overridden to provide a serialized version of the derived + * class' state. + * + * @return the serialized data or an error if serialization + * isn't supported or failed. + */ + virtual broker::expected DoSerialize() const = 0; - /** - * Must be overridden to recreate the derived class' state from a - * serialization. - * - * @return true if successful. - */ - virtual bool DoUnserialize(const broker::data& data) = 0; + /** + * Must be overridden to recreate the derived class' state from a + * serialization. + * + * @return true if successful. + */ + virtual bool DoUnserialize(const broker::data& data) = 0; - /** - * Internal helper for the serialization machinery. Automatically - * overridden by the `DECLARE_OPAQUE_VALUE` macro. - */ - virtual const char* OpaqueName() const = 0; + /** + * Internal helper for the serialization machinery. Automatically + * overridden by the `DECLARE_OPAQUE_VALUE` macro. + */ + virtual const char* OpaqueName() const = 0; - /** - * Provides an implementation of *Val::DoClone()* that leverages the - * serialization methods to deep-copy an instance. Derived classes - * may also override this with a more efficient custom clone - * implementation of their own. - */ - ValPtr DoClone(CloneState* state) override; + /** + * Provides an implementation of *Val::DoClone()* that leverages the + * serialization methods to deep-copy an instance. Derived classes + * may also override this with a more efficient custom clone + * implementation of their own. + */ + ValPtr DoClone(CloneState* state) override; - /** - * Helper function for derived class that need to record a type - * during serialization. - */ - static broker::expected SerializeType(const TypePtr& t); + /** + * Helper function for derived class that need to record a type + * during serialization. + */ + static broker::expected SerializeType(const TypePtr& t); - /** - * Helper function for derived class that need to restore a type - * during unserialization. Returns the type at reference count +1. - */ - static TypePtr UnserializeType(const broker::data& data); + /** + * Helper function for derived class that need to restore a type + * during unserialization. Returns the type at reference count +1. + */ + static TypePtr UnserializeType(const broker::data& data); - void ValDescribe(ODesc* d) const override; - void ValDescribeReST(ODesc* d) const override; - }; + void ValDescribe(ODesc* d) const override; + void ValDescribeReST(ODesc* d) const override; +}; -class HashVal : public OpaqueVal - { +class HashVal : public OpaqueVal { public: - template - static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) - { - auto h = detail::hash_init(alg); + template + static void digest_all(detail::HashAlgorithm alg, const T& vlist, u_char* result) { + auto h = detail::hash_init(alg); - for ( const auto& v : vlist ) - digest_one(h, v); + for ( const auto& v : vlist ) + digest_one(h, v); - detail::hash_final(h, result); - } + detail::hash_final(h, result); + } - bool IsValid() const; - bool Init(); - bool Feed(const void* data, size_t size); - StringValPtr Get(); + bool IsValid() const; + bool Init(); + bool Feed(const void* data, size_t size); + StringValPtr Get(); protected: - static void digest_one(EVP_MD_CTX* h, const Val* v); - static void digest_one(EVP_MD_CTX* h, const ValPtr& v); + static void digest_one(EVP_MD_CTX* h, const Val* v); + static void digest_one(EVP_MD_CTX* h, const ValPtr& v); - explicit HashVal(OpaqueTypePtr t); + explicit HashVal(OpaqueTypePtr t); - virtual bool DoInit(); - virtual bool DoFeed(const void* data, size_t size); - virtual StringValPtr DoGet(); + virtual bool DoInit(); + virtual bool DoFeed(const void* data, size_t size); + virtual StringValPtr DoGet(); private: - // This flag exists because Get() can only be called once. - bool valid; - }; + // This flag exists because Get() can only be called once. + bool valid; +}; -class MD5Val : public HashVal - { +class MD5Val : public HashVal { public: - template static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) - { - digest_all(detail::Hash_MD5, vlist, result); - } + template + static void digest(const T& vlist, u_char result[MD5_DIGEST_LENGTH]) { + digest_all(detail::Hash_MD5, vlist, result); + } - template - static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], - u_char result[MD5_DIGEST_LENGTH]) - { - digest(vlist, result); + template + static void hmac(const T& vlist, u_char key[MD5_DIGEST_LENGTH], u_char result[MD5_DIGEST_LENGTH]) { + digest(vlist, result); - for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i ) - result[i] ^= key[i]; + for ( int i = 0; i < MD5_DIGEST_LENGTH; ++i ) + result[i] ^= key[i]; - detail::internal_md5(result, MD5_DIGEST_LENGTH, result); - } + detail::internal_md5(result, MD5_DIGEST_LENGTH, result); + } - MD5Val(); - ~MD5Val(); + MD5Val(); + ~MD5Val(); - ValPtr DoClone(CloneState* state) override; + ValPtr DoClone(CloneState* state) override; protected: - friend class Val; + friend class Val; - bool DoInit() override; - bool DoFeed(const void* data, size_t size) override; - StringValPtr DoGet() override; + bool DoInit() override; + bool DoFeed(const void* data, size_t size) override; + StringValPtr DoGet() override; - DECLARE_OPAQUE_VALUE(MD5Val) + DECLARE_OPAQUE_VALUE(MD5Val) private: #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - EVP_MD_CTX* ctx; + EVP_MD_CTX* ctx; #else - MD5_CTX ctx; + MD5_CTX ctx; #endif - }; +}; -class SHA1Val : public HashVal - { +class SHA1Val : public HashVal { public: - template static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) - { - digest_all(detail::Hash_SHA1, vlist, result); - } + template + static void digest(const T& vlist, u_char result[SHA_DIGEST_LENGTH]) { + digest_all(detail::Hash_SHA1, vlist, result); + } - SHA1Val(); - ~SHA1Val(); + SHA1Val(); + ~SHA1Val(); - ValPtr DoClone(CloneState* state) override; + ValPtr DoClone(CloneState* state) override; protected: - friend class Val; + friend class Val; - bool DoInit() override; - bool DoFeed(const void* data, size_t size) override; - StringValPtr DoGet() override; + bool DoInit() override; + bool DoFeed(const void* data, size_t size) override; + StringValPtr DoGet() override; - DECLARE_OPAQUE_VALUE(SHA1Val) + DECLARE_OPAQUE_VALUE(SHA1Val) private: #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - EVP_MD_CTX* ctx; + EVP_MD_CTX* ctx; #else - SHA_CTX ctx; + SHA_CTX ctx; #endif - }; +}; -class SHA256Val : public HashVal - { +class SHA256Val : public HashVal { public: - template static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) - { - digest_all(detail::Hash_SHA256, vlist, result); - } + template + static void digest(const T& vlist, u_char result[SHA256_DIGEST_LENGTH]) { + digest_all(detail::Hash_SHA256, vlist, result); + } - SHA256Val(); - ~SHA256Val(); + SHA256Val(); + ~SHA256Val(); - ValPtr DoClone(CloneState* state) override; + ValPtr DoClone(CloneState* state) override; protected: - friend class Val; + friend class Val; - bool DoInit() override; - bool DoFeed(const void* data, size_t size) override; - StringValPtr DoGet() override; + bool DoInit() override; + bool DoFeed(const void* data, size_t size) override; + StringValPtr DoGet() override; - DECLARE_OPAQUE_VALUE(SHA256Val) + DECLARE_OPAQUE_VALUE(SHA256Val) private: #if ( OPENSSL_VERSION_NUMBER < 0x30000000L ) || defined(LIBRESSL_VERSION_NUMBER) - EVP_MD_CTX* ctx; + EVP_MD_CTX* ctx; #else - SHA256_CTX ctx; + SHA256_CTX ctx; #endif - }; +}; -class EntropyVal : public OpaqueVal - { +class EntropyVal : public OpaqueVal { public: - EntropyVal(); + EntropyVal(); - bool Feed(const void* data, size_t size); - bool Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, double* r_scc); + bool Feed(const void* data, size_t size); + bool Get(double* r_ent, double* r_chisq, double* r_mean, double* r_montepicalc, double* r_scc); protected: - friend class Val; + friend class Val; - DECLARE_OPAQUE_VALUE(EntropyVal) + DECLARE_OPAQUE_VALUE(EntropyVal) private: - detail::RandTest state; - }; + detail::RandTest state; +}; -class BloomFilterVal : public OpaqueVal - { +class BloomFilterVal : public OpaqueVal { public: - explicit BloomFilterVal(probabilistic::BloomFilter* bf); - ~BloomFilterVal() override; + explicit BloomFilterVal(probabilistic::BloomFilter* bf); + ~BloomFilterVal() override; - ValPtr DoClone(CloneState* state) override; + ValPtr DoClone(CloneState* state) override; - const TypePtr& Type() const { return type; } + const TypePtr& Type() const { return type; } - bool Typify(TypePtr type); + bool Typify(TypePtr type); - void Add(const Val* val); - bool Decrement(const Val* val); - size_t Count(const Val* val) const; - void Clear(); - bool Empty() const; - std::string InternalState() const; + void Add(const Val* val); + bool Decrement(const Val* val); + size_t Count(const Val* val) const; + void Clear(); + bool Empty() const; + std::string InternalState() const; - static BloomFilterValPtr Merge(const BloomFilterVal* x, const BloomFilterVal* y); - static BloomFilterValPtr Intersect(const BloomFilterVal* x, const BloomFilterVal* y); + static BloomFilterValPtr Merge(const BloomFilterVal* x, const BloomFilterVal* y); + static BloomFilterValPtr Intersect(const BloomFilterVal* x, const BloomFilterVal* y); protected: - friend class Val; - BloomFilterVal(); + friend class Val; + BloomFilterVal(); - DECLARE_OPAQUE_VALUE(BloomFilterVal) + DECLARE_OPAQUE_VALUE(BloomFilterVal) private: - // Disable. - BloomFilterVal(const BloomFilterVal&); - BloomFilterVal& operator=(const BloomFilterVal&); + // Disable. + BloomFilterVal(const BloomFilterVal&); + BloomFilterVal& operator=(const BloomFilterVal&); - TypePtr type; - detail::CompositeHash* hash; - probabilistic::BloomFilter* bloom_filter; - }; + TypePtr type; + detail::CompositeHash* hash; + probabilistic::BloomFilter* bloom_filter; +}; -class CardinalityVal : public OpaqueVal - { +class CardinalityVal : public OpaqueVal { public: - explicit CardinalityVal(probabilistic::detail::CardinalityCounter*); - ~CardinalityVal() override; + explicit CardinalityVal(probabilistic::detail::CardinalityCounter*); + ~CardinalityVal() override; - ValPtr DoClone(CloneState* state) override; + ValPtr DoClone(CloneState* state) override; - void Add(const Val* val); + void Add(const Val* val); - const TypePtr& Type() const { return type; } + const TypePtr& Type() const { return type; } - bool Typify(TypePtr type); + bool Typify(TypePtr type); - probabilistic::detail::CardinalityCounter* Get() { return c; }; + probabilistic::detail::CardinalityCounter* Get() { return c; }; protected: - CardinalityVal(); + CardinalityVal(); - DECLARE_OPAQUE_VALUE(CardinalityVal) + DECLARE_OPAQUE_VALUE(CardinalityVal) private: - TypePtr type; - detail::CompositeHash* hash; - probabilistic::detail::CardinalityCounter* c; - }; + TypePtr type; + detail::CompositeHash* hash; + probabilistic::detail::CardinalityCounter* c; +}; -class ParaglobVal : public OpaqueVal - { +class ParaglobVal : public OpaqueVal { public: - explicit ParaglobVal(std::unique_ptr p); - VectorValPtr Get(StringVal*& pattern); - ValPtr DoClone(CloneState* state) override; - bool operator==(const ParaglobVal& other) const; + explicit ParaglobVal(std::unique_ptr p); + VectorValPtr Get(StringVal*& pattern); + ValPtr DoClone(CloneState* state) override; + bool operator==(const ParaglobVal& other) const; protected: - ParaglobVal() : OpaqueVal(paraglob_type) { } + ParaglobVal() : OpaqueVal(paraglob_type) {} - DECLARE_OPAQUE_VALUE(ParaglobVal) + DECLARE_OPAQUE_VALUE(ParaglobVal) private: - std::unique_ptr internal_paraglob; - }; + std::unique_ptr internal_paraglob; +}; /** * Base class for metric handles. Handle types are not serializable. */ -class TelemetryVal : public OpaqueVal - { +class TelemetryVal : public OpaqueVal { protected: - explicit TelemetryVal(telemetry::IntCounter); - explicit TelemetryVal(telemetry::IntCounterFamily); - explicit TelemetryVal(telemetry::DblCounter); - explicit TelemetryVal(telemetry::DblCounterFamily); - explicit TelemetryVal(telemetry::IntGauge); - explicit TelemetryVal(telemetry::IntGaugeFamily); - explicit TelemetryVal(telemetry::DblGauge); - explicit TelemetryVal(telemetry::DblGaugeFamily); - explicit TelemetryVal(telemetry::IntHistogram); - explicit TelemetryVal(telemetry::IntHistogramFamily); - explicit TelemetryVal(telemetry::DblHistogram); - explicit TelemetryVal(telemetry::DblHistogramFamily); + explicit TelemetryVal(telemetry::IntCounter); + explicit TelemetryVal(telemetry::IntCounterFamily); + explicit TelemetryVal(telemetry::DblCounter); + explicit TelemetryVal(telemetry::DblCounterFamily); + explicit TelemetryVal(telemetry::IntGauge); + explicit TelemetryVal(telemetry::IntGaugeFamily); + explicit TelemetryVal(telemetry::DblGauge); + explicit TelemetryVal(telemetry::DblGaugeFamily); + explicit TelemetryVal(telemetry::IntHistogram); + explicit TelemetryVal(telemetry::IntHistogramFamily); + explicit TelemetryVal(telemetry::DblHistogram); + explicit TelemetryVal(telemetry::DblHistogramFamily); - broker::expected DoSerialize() const override; - bool DoUnserialize(const broker::data& data) override; - }; + broker::expected DoSerialize() const override; + bool DoUnserialize(const broker::data& data) override; +}; -template class TelemetryValImpl : public TelemetryVal - { +template +class TelemetryValImpl : public TelemetryVal { public: - using HandleType = Handle; + using HandleType = Handle; - explicit TelemetryValImpl(Handle hdl) : TelemetryVal(hdl), hdl(hdl) { } + explicit TelemetryValImpl(Handle hdl) : TelemetryVal(hdl), hdl(hdl) {} - Handle GetHandle() const noexcept { return hdl; } + Handle GetHandle() const noexcept { return hdl; } protected: - ValPtr DoClone(CloneState*) override { return make_intrusive(hdl); } + ValPtr DoClone(CloneState*) override { return make_intrusive(hdl); } - const char* OpaqueName() const override { return Handle::OpaqueName; } + const char* OpaqueName() const override { return Handle::OpaqueName; } private: - Handle hdl; - }; + Handle hdl; +}; using IntCounterMetricVal = TelemetryValImpl; using IntCounterMetricFamilyVal = TelemetryValImpl; @@ -470,4 +452,4 @@ using IntHistogramMetricFamilyVal = TelemetryValImpl; using DblHistogramMetricFamilyVal = TelemetryValImpl; - } // namespace zeek +} // namespace zeek diff --git a/src/Options.cc b/src/Options.cc index 86e8d01e51..795f5e6948 100644 --- a/src/Options.cc +++ b/src/Options.cc @@ -19,712 +19,592 @@ #include "zeek/logging/writers/ascii/Ascii.h" #include "zeek/script_opt/ScriptOpt.h" -namespace zeek - { +namespace zeek { -void Options::filter_supervisor_options() - { - pcap_filter = {}; - signature_files = {}; - pcap_output_file = {}; - } +void Options::filter_supervisor_options() { + pcap_filter = {}; + signature_files = {}; + pcap_output_file = {}; +} -void Options::filter_supervised_node_options() - { - auto og = *this; - *this = {}; +void Options::filter_supervised_node_options() { + auto og = *this; + *this = {}; - debug_log_streams = og.debug_log_streams; - debug_script_tracing_file = og.debug_script_tracing_file; - script_code_to_exec = og.script_code_to_exec; - script_prefixes = og.script_prefixes; + debug_log_streams = og.debug_log_streams; + debug_script_tracing_file = og.debug_script_tracing_file; + script_code_to_exec = og.script_code_to_exec; + script_prefixes = og.script_prefixes; - signature_re_level = og.signature_re_level; - ignore_checksums = og.ignore_checksums; - use_watchdog = og.use_watchdog; - pseudo_realtime = og.pseudo_realtime; - dns_mode = og.dns_mode; + signature_re_level = og.signature_re_level; + ignore_checksums = og.ignore_checksums; + use_watchdog = og.use_watchdog; + pseudo_realtime = og.pseudo_realtime; + dns_mode = og.dns_mode; - bare_mode = og.bare_mode; - perftools_check_leaks = og.perftools_check_leaks; - perftools_profile = og.perftools_profile; - deterministic_mode = og.deterministic_mode; - abort_on_scripting_errors = og.abort_on_scripting_errors; + bare_mode = og.bare_mode; + perftools_check_leaks = og.perftools_check_leaks; + perftools_profile = og.perftools_profile; + deterministic_mode = og.deterministic_mode; + abort_on_scripting_errors = og.abort_on_scripting_errors; - pcap_filter = og.pcap_filter; - signature_files = og.signature_files; + pcap_filter = og.pcap_filter; + signature_files = og.signature_files; - // TODO: These are likely to be handled in a node-specific or - // use-case-specific way. e.g. interfaces is already handled for the - // "cluster" use-case, but don't have supervised-pcap-reading - // functionality yet. - /* interface = og.interface; */ - /* pcap_file = og.pcap_file; */ + // TODO: These are likely to be handled in a node-specific or + // use-case-specific way. e.g. interfaces is already handled for the + // "cluster" use-case, but don't have supervised-pcap-reading + // functionality yet. + /* interface = og.interface; */ + /* pcap_file = og.pcap_file; */ - pcap_output_file = og.pcap_output_file; - random_seed_input_file = og.random_seed_input_file; - random_seed_output_file = og.random_seed_output_file; - process_status_file = og.process_status_file; + pcap_output_file = og.pcap_output_file; + random_seed_input_file = og.random_seed_input_file; + random_seed_output_file = og.random_seed_output_file; + process_status_file = og.process_status_file; - plugins_to_load = og.plugins_to_load; - scripts_to_load = og.scripts_to_load; - script_options_to_set = og.script_options_to_set; - } + plugins_to_load = og.plugins_to_load; + scripts_to_load = og.scripts_to_load; + script_options_to_set = og.script_options_to_set; +} -bool fake_dns() - { - return getenv("ZEEK_DNS_FAKE"); - } +bool fake_dns() { return getenv("ZEEK_DNS_FAKE"); } extern const char* zeek_version(); -void usage(const char* prog, int code) - { - fprintf(stderr, "zeek version %s\n", zeek_version()); +void usage(const char* prog, int code) { + fprintf(stderr, "zeek version %s\n", zeek_version()); - fprintf(stderr, "usage: %s [options] [file ...]\n", prog); - fprintf(stderr, "usage: %s --test [doctest-options] -- [options] [file ...]\n", prog); - fprintf(stderr, " | Zeek script file, or read stdin\n"); - fprintf(stderr, - " -a|--parse-only | exit immediately after parsing scripts\n"); - fprintf(stderr, - " -b|--bare-mode | don't load scripts from the base/ directory\n"); - fprintf(stderr, - " -c|--capture-unprocessed | write unprocessed packets to a tcpdump file\n"); - fprintf(stderr, " -d|--debug-script | activate Zeek script debugging\n"); - fprintf(stderr, " -e|--exec | augment loaded scripts by given code\n"); - fprintf(stderr, " -f|--filter | tcpdump filter\n"); - fprintf(stderr, " -h|--help | command line help\n"); - fprintf(stderr, - " -i|--iface | read from given interface (only one allowed)\n"); - fprintf( - stderr, - " -p|--prefix | add given prefix to Zeek script file resolution\n"); - fprintf(stderr, " -r|--readfile | read from given tcpdump file (only one " - "allowed, pass '-' as the filename to read from stdin)\n"); - fprintf(stderr, " -s|--rulefile | read rules from given file\n"); - fprintf(stderr, " -t|--tracefile | activate execution tracing\n"); - fprintf(stderr, " -u|--usage-issues | find variable usage issues and exit\n"); - fprintf(stderr, " --no-unused-warnings | suppress warnings of unused " - "functions/hooks/events\n"); - fprintf(stderr, " -v|--version | print version and exit\n"); - fprintf(stderr, " -V|--build-info | print build information and exit\n"); - fprintf(stderr, " -w|--writefile | write to given tcpdump file\n"); + fprintf(stderr, "usage: %s [options] [file ...]\n", prog); + fprintf(stderr, "usage: %s --test [doctest-options] -- [options] [file ...]\n", prog); + fprintf(stderr, " | Zeek script file, or read stdin\n"); + fprintf(stderr, " -a|--parse-only | exit immediately after parsing scripts\n"); + fprintf(stderr, " -b|--bare-mode | don't load scripts from the base/ directory\n"); + fprintf(stderr, " -c|--capture-unprocessed | write unprocessed packets to a tcpdump file\n"); + fprintf(stderr, " -d|--debug-script | activate Zeek script debugging\n"); + fprintf(stderr, " -e|--exec | augment loaded scripts by given code\n"); + fprintf(stderr, " -f|--filter | tcpdump filter\n"); + fprintf(stderr, " -h|--help | command line help\n"); + fprintf(stderr, " -i|--iface | read from given interface (only one allowed)\n"); + fprintf(stderr, " -p|--prefix | add given prefix to Zeek script file resolution\n"); + fprintf(stderr, + " -r|--readfile | read from given tcpdump file (only one " + "allowed, pass '-' as the filename to read from stdin)\n"); + fprintf(stderr, " -s|--rulefile | read rules from given file\n"); + fprintf(stderr, " -t|--tracefile | activate execution tracing\n"); + fprintf(stderr, " -u|--usage-issues | find variable usage issues and exit\n"); + fprintf(stderr, + " --no-unused-warnings | suppress warnings of unused " + "functions/hooks/events\n"); + fprintf(stderr, " -v|--version | print version and exit\n"); + fprintf(stderr, " -V|--build-info | print build information and exit\n"); + fprintf(stderr, " -w|--writefile | write to given tcpdump file\n"); #ifdef DEBUG - fprintf(stderr, " -B|--debug | Enable debugging output for selected " - "streams ('-B help' for help)\n"); + fprintf(stderr, + " -B|--debug | Enable debugging output for selected " + "streams ('-B help' for help)\n"); #endif - fprintf(stderr, " -C|--no-checksums | ignore checksums\n"); - fprintf(stderr, " -D|--deterministic | initialize random seeds to zero\n"); - fprintf(stderr, " -E|--event-trace | generate a replayable event trace to " - "the given file\n"); - fprintf(stderr, " -F|--force-dns | force DNS\n"); - fprintf(stderr, " -G|--load-seeds | load seeds from given file\n"); - fprintf(stderr, " -H|--save-seeds | save seeds to given file\n"); - fprintf(stderr, " -I|--print-id | print out given ID\n"); - fprintf(stderr, " -N|--print-plugins | print available plugins and exit (-NN " - "for verbose)\n"); - fprintf(stderr, " -O|--optimize