diff --git a/CHANGES b/CHANGES index 7f12f02dda..0678c65163 100644 --- a/CHANGES +++ b/CHANGES @@ -1,4 +1,10 @@ +3.1.0-dev.319 | 2020-01-06 09:44:11 -0800 + + * Mark safe_snprintf and safe_vsnprintf as deprecated, remove uses of them (Tim Wojtulewicz, Corelight) + + * Add unit tests to util.cc and module_util.cc (Tim Wojtulewicz, Corelight) + 3.1.0-dev.314 | 2019-12-18 13:36:07 -0800 * Add GitHub Action for CI notification emails (Jon Siwek, Corelight) diff --git a/NEWS b/NEWS index b3e957f518..23557fc01d 100644 --- a/NEWS +++ b/NEWS @@ -72,6 +72,9 @@ Deprecated Functionality in favor of the real types they alias. E.g. use int8_t instead of int8. +- The C++ API functions "safe_snprintf" and "safe_vsnprintf" are deprecated. + Use "snprintf" and "vsnprintf" instead. + Zeek 3.0.0 ========== diff --git a/VERSION b/VERSION index 246e324634..6f8d6c8770 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -3.1.0-dev.314 +3.1.0-dev.319 diff --git a/src/DbgBreakpoint.cc b/src/DbgBreakpoint.cc index b1223486d3..50f5295fc5 100644 --- a/src/DbgBreakpoint.cc +++ b/src/DbgBreakpoint.cc @@ -135,7 +135,7 @@ bool DbgBreakpoint::SetLocation(ParseLocationRec plr, string loc_str) } at_stmt = plr.stmt; - safe_snprintf(description, sizeof(description), "%s:%d", + snprintf(description, sizeof(description), "%s:%d", source_filename, source_line); debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); @@ -148,7 +148,7 @@ bool DbgBreakpoint::SetLocation(ParseLocationRec plr, string loc_str) loc_str.c_str()); at_stmt = plr.stmt; const Location* loc = at_stmt->GetLocationInfo(); - safe_snprintf(description, sizeof(description), "%s at %s:%d", + snprintf(description, sizeof(description), "%s at %s:%d", function_name.c_str(), loc->filename, loc->last_line); debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); @@ -171,7 +171,7 @@ bool DbgBreakpoint::SetLocation(Stmt* stmt) AddToGlobalMap(); const Location* loc = stmt->GetLocationInfo(); - safe_snprintf(description, 
sizeof(description), "%s:%d", + snprintf(description, sizeof(description), "%s:%d", loc->filename, loc->last_line); debug_msg("Breakpoint %d set at %s\n", GetID(), Description()); diff --git a/src/Debug.cc b/src/Debug.cc index d06996b3ce..946fca7bd4 100644 --- a/src/Debug.cc +++ b/src/Debug.cc @@ -717,7 +717,7 @@ static char* get_prompt(bool reset_counter = false) if ( reset_counter ) counter = 0; - safe_snprintf(prompt, sizeof(prompt), "(Zeek [%d]) ", counter++); + snprintf(prompt, sizeof(prompt), "(Zeek [%d]) ", counter++); return prompt; } @@ -743,7 +743,7 @@ string get_context_description(const Stmt* stmt, const Frame* frame) size_t buf_size = strlen(d.Description()) + strlen(loc.filename) + 1024; char* buf = new char[buf_size]; - safe_snprintf(buf, buf_size, "In %s at %s:%d", + snprintf(buf, buf_size, "In %s at %s:%d", d.Description(), loc.filename, loc.last_line); string retval(buf); diff --git a/src/RE.cc b/src/RE.cc index 5b758dfb41..be92ec6e41 100644 --- a/src/RE.cc +++ b/src/RE.cc @@ -109,7 +109,7 @@ void Specific_RE_Matcher::MakeCaseInsensitive() char* s = new char[n + 5 /* slop */]; - safe_snprintf(s, n + 5, fmt, pattern_text); + snprintf(s, n + 5, fmt, pattern_text); delete [] pattern_text; pattern_text = s; @@ -493,7 +493,7 @@ static RE_Matcher* matcher_merge(const RE_Matcher* re1, const RE_Matcher* re2, int n = strlen(text1) + strlen(text2) + strlen(merge_op) + 32 /* slop */ ; char* merge_text = new char[n]; - safe_snprintf(merge_text, n, "(%s)%s(%s)", text1, merge_op, text2); + snprintf(merge_text, n, "(%s)%s(%s)", text1, merge_op, text2); RE_Matcher* merge = new RE_Matcher(merge_text); delete [] merge_text; diff --git a/src/Reporter.cc b/src/Reporter.cc index bf108879ae..5b1de82eff 100644 --- a/src/Reporter.cc +++ b/src/Reporter.cc @@ -430,7 +430,7 @@ void Reporter::DoLog(const char* prefix, EventHandlerPtr event, FILE* out, { va_list aq; va_copy(aq, ap); - int n = safe_vsnprintf(buffer, size, fmt, aq); + int n = vsnprintf(buffer, size, fmt, aq); 
va_end(aq); if ( postfix ) @@ -451,7 +451,7 @@ void Reporter::DoLog(const char* prefix, EventHandlerPtr event, FILE* out, if ( postfix && *postfix ) // Note, if you change this fmt string, adjust the additional // buffer size above. - safe_snprintf(buffer + strlen(buffer), size - strlen(buffer), " (%s)", postfix); + snprintf(buffer + strlen(buffer), size - strlen(buffer), " (%s)", postfix); bool raise_event = true; diff --git a/src/Val.cc b/src/Val.cc index 907a29d062..d3f3a9eced 100644 --- a/src/Val.cc +++ b/src/Val.cc @@ -2584,7 +2584,7 @@ RecordVal* RecordVal::CoerceTo(const RecordType* t, Val* aggr, bool allow_orphan continue; char buf[512]; - safe_snprintf(buf, sizeof(buf), + snprintf(buf, sizeof(buf), "orphan field \"%s\" in initialization", rv_t->FieldName(i)); Error(buf); @@ -2614,7 +2614,7 @@ RecordVal* RecordVal::CoerceTo(const RecordType* t, Val* aggr, bool allow_orphan ! ar_t->FieldDecl(i)->FindAttr(ATTR_OPTIONAL) ) { char buf[512]; - safe_snprintf(buf, sizeof(buf), + snprintf(buf, sizeof(buf), "non-optional field \"%s\" missing in initialization", ar_t->FieldName(i)); Error(buf); } diff --git a/src/analyzer/protocol/smtp/SMTP.cc b/src/analyzer/protocol/smtp/SMTP.cc index 0628e6e1b5..bdc6503ba1 100644 --- a/src/analyzer/protocol/smtp/SMTP.cc +++ b/src/analyzer/protocol/smtp/SMTP.cc @@ -889,7 +889,7 @@ void SMTP_Analyzer::UnexpectedCommand(const int cmd_code, const int reply_code) // If this happens, please fix the SMTP state machine! // ### Eventually, these should be turned into "weird" events. static char buf[512]; - int len = safe_snprintf(buf, sizeof(buf), + int len = snprintf(buf, sizeof(buf), "%s reply = %d state = %d", SMTP_CMD_WORD(cmd_code), reply_code, state); if ( len > (int) sizeof(buf) ) @@ -902,7 +902,7 @@ void SMTP_Analyzer::UnexpectedReply(const int cmd_code, const int reply_code) // If this happens, please fix the SMTP state machine! // ### Eventually, these should be turned into "weird" events. 
static char buf[512]; - int len = safe_snprintf(buf, sizeof(buf), + int len = snprintf(buf, sizeof(buf), "%d state = %d, last command = %s", reply_code, state, SMTP_CMD_WORD(cmd_code)); Unexpected (1, "unexpected reply", len, buf); diff --git a/src/analyzer/protocol/tcp/Stats.cc b/src/analyzer/protocol/tcp/Stats.cc index b5337fa2db..9dca450ab7 100644 --- a/src/analyzer/protocol/tcp/Stats.cc +++ b/src/analyzer/protocol/tcp/Stats.cc @@ -71,7 +71,7 @@ void TCPStateStats::PrintStats(BroFile* file, const char* prefix) if ( n > 0 ) { char buf[32]; - safe_snprintf(buf, sizeof(buf), "%-8d", state_cnt[i][j]); + snprintf(buf, sizeof(buf), "%-8d", state_cnt[i][j]); file->Write(buf); } else diff --git a/src/iosource/BPF_Program.cc b/src/iosource/BPF_Program.cc index 4a8e6383d1..a99acd07e4 100644 --- a/src/iosource/BPF_Program.cc +++ b/src/iosource/BPF_Program.cc @@ -85,7 +85,7 @@ bool BPF_Program::Compile(pcap_t* pcap, const char* filter, uint32_t netmask, if ( pcap_compile(pcap, &m_program, (char *) filter, optimize, netmask) < 0 ) { if ( errbuf ) - safe_snprintf(errbuf, errbuf_len, + snprintf(errbuf, errbuf_len, "pcap_compile(%s): %s", filter, pcap_geterr(pcap)); diff --git a/src/iosource/PktSrc.cc b/src/iosource/PktSrc.cc index b30a94d08d..0eae93ac27 100644 --- a/src/iosource/PktSrc.cc +++ b/src/iosource/PktSrc.cc @@ -93,7 +93,7 @@ void PktSrc::Opened(const Properties& arg_props) if ( Packet::GetLinkHeaderSize(arg_props.link_type) < 0 ) { char buf[512]; - safe_snprintf(buf, sizeof(buf), + snprintf(buf, sizeof(buf), "unknown data link type 0x%x", arg_props.link_type); Error(buf); Close(); diff --git a/src/iosource/pcap/Source.cc b/src/iosource/pcap/Source.cc index 44f642ccd1..0c8efd643d 100644 --- a/src/iosource/pcap/Source.cc +++ b/src/iosource/pcap/Source.cc @@ -62,7 +62,7 @@ void PcapSource::OpenLive() if ( pcap_findalldevs(&devs, tmp_errbuf) < 0 ) { - safe_snprintf(errbuf, sizeof(errbuf), + snprintf(errbuf, sizeof(errbuf), "pcap_findalldevs: %s", tmp_errbuf); 
Error(errbuf); return; @@ -75,7 +75,7 @@ void PcapSource::OpenLive() if ( props.path.empty() ) { - safe_snprintf(errbuf, sizeof(errbuf), + snprintf(errbuf, sizeof(errbuf), "pcap_findalldevs: empty device name"); Error(errbuf); return; @@ -83,7 +83,7 @@ void PcapSource::OpenLive() } else { - safe_snprintf(errbuf, sizeof(errbuf), + snprintf(errbuf, sizeof(errbuf), "pcap_findalldevs: no devices found"); Error(errbuf); return; @@ -263,7 +263,7 @@ bool PcapSource::SetFilter(int index) if ( ! code ) { - safe_snprintf(errbuf, sizeof(errbuf), + snprintf(errbuf, sizeof(errbuf), "No precompiled pcap filter for index %d", index); Error(errbuf); diff --git a/src/module_util.cc b/src/module_util.cc index 500d7f4d08..e0737fc35e 100644 --- a/src/module_util.cc +++ b/src/module_util.cc @@ -1,10 +1,14 @@ // // See the file "COPYING" in the main distribution directory for copyright. -#include -#include #include "module_util.h" +#include +#include +#include + +#include "3rdparty/doctest.h" + using namespace std; static int streq(const char* s1, const char* s2) @@ -12,7 +16,20 @@ static int streq(const char* s1, const char* s2) return ! strcmp(s1, s2); } -// Returns it without trailing "::". +TEST_CASE("module_util streq") + { + CHECK(streq("abcd", "abcd") == true); + CHECK(streq("abcd", "efgh") == false); + } + +TEST_CASE("module_util extract_module_name") + { + CHECK(extract_module_name("mod") == GLOBAL_MODULE_NAME); + CHECK(extract_module_name("mod::") == "mod"); + CHECK(extract_module_name("mod::var") == "mod"); + } + +// Returns it without trailing "::" var section. 
string extract_module_name(const char* name) { string module_name = name; @@ -26,6 +43,14 @@ string extract_module_name(const char* name) return module_name; } +TEST_CASE("module_util extract_var_name") + { + CHECK(extract_var_name("mod") == "mod"); + CHECK(extract_var_name("mod::") == ""); + CHECK(extract_var_name("mod::var") == "var"); + CHECK(extract_var_name("::var") == "var"); + } + string extract_var_name(const char *name) { string var_name = name; @@ -40,6 +65,13 @@ string extract_var_name(const char *name) return var_name.substr(pos+2); } +TEST_CASE("module_util normalized_module_name") + { + CHECK(normalized_module_name("a") == "a"); + CHECK(normalized_module_name("module") == "module"); + CHECK(normalized_module_name("module::") == "module"); + } + string normalized_module_name(const char* module_name) { int mod_len; @@ -50,6 +82,18 @@ string normalized_module_name(const char* module_name) return string(module_name, mod_len); } +TEST_CASE("module_util make_full_var_name") + { + CHECK(make_full_var_name(nullptr, "GLOBAL::var") == "var"); + CHECK(make_full_var_name(GLOBAL_MODULE_NAME, "var") == "var"); + CHECK(make_full_var_name(nullptr, "notglobal::var") == "notglobal::var"); + CHECK(make_full_var_name(nullptr, "::var") == "::var"); + + CHECK(make_full_var_name("module", "var") == "module::var"); + CHECK(make_full_var_name("module::", "var") == "module::var"); + CHECK(make_full_var_name("", "var") == "::var"); + } + string make_full_var_name(const char* module_name, const char* var_name) { if ( ! 
module_name || streq(module_name, GLOBAL_MODULE_NAME) || diff --git a/src/net_util.cc b/src/net_util.cc index 1be7969ca8..549d257d7e 100644 --- a/src/net_util.cc +++ b/src/net_util.cc @@ -136,7 +136,7 @@ const char* fmt_conn_id(const IPAddr& src_addr, uint32_t src_port, { static char buffer[512]; - safe_snprintf(buffer, sizeof(buffer), "%s:%d > %s:%d", + snprintf(buffer, sizeof(buffer), "%s:%d > %s:%d", string(src_addr).c_str(), src_port, string(dst_addr).c_str(), dst_port); diff --git a/src/rule-scan.l b/src/rule-scan.l index c75bb5fa9c..e9d2b4fece 100644 --- a/src/rule-scan.l +++ b/src/rule-scan.l @@ -197,7 +197,7 @@ finger { rules_lval.val = Rule::FINGER; return TOK_PATTERN_TYPE; } const char fmt[] = "(?i:%s)"; int n = len + strlen(fmt); char* s = new char[n + 5 /* slop */]; - safe_snprintf(s, n + 5, fmt, yytext + 1); + snprintf(s, n + 5, fmt, yytext + 1); rules_lval.str = s; } else diff --git a/src/strings.bif b/src/strings.bif index 89173ead77..2553b7d068 100644 --- a/src/strings.bif +++ b/src/strings.bif @@ -1096,7 +1096,7 @@ function hexdump%(data_str: string%) : string if ( x == 0 ) { char offset[5]; - safe_snprintf(offset, sizeof(offset), + snprintf(offset, sizeof(offset), "%.4x", data_ptr - data); memcpy(hex_data_ptr, offset, 4); hex_data_ptr += 6; @@ -1104,7 +1104,7 @@ function hexdump%(data_str: string%) : string } char hex_byte[3]; - safe_snprintf(hex_byte, sizeof(hex_byte), + snprintf(hex_byte, sizeof(hex_byte), "%.2x", (u_char) *data_ptr); int val = (u_char) *data_ptr; diff --git a/src/threading/BasicThread.cc b/src/threading/BasicThread.cc index 67434957e5..f8257f5b49 100644 --- a/src/threading/BasicThread.cc +++ b/src/threading/BasicThread.cc @@ -79,7 +79,7 @@ const char* BasicThread::Fmt(const char* format, ...) 
va_list al; va_start(al, format); - int n = safe_vsnprintf(buf, buf_len, format, al); + int n = vsnprintf(buf, buf_len, format, al); va_end(al); if ( (unsigned int) n >= buf_len ) @@ -89,7 +89,7 @@ const char* BasicThread::Fmt(const char* format, ...) // Is it portable to restart? va_start(al, format); - n = safe_vsnprintf(buf, buf_len, format, al); + n = vsnprintf(buf, buf_len, format, al); va_end(al); } diff --git a/src/util.cc b/src/util.cc index 461835964e..9a5e2d9355 100644 --- a/src/util.cc +++ b/src/util.cc @@ -53,6 +53,15 @@ #include "iosource/Manager.h" #include "ConvertUTF.h" +#include "3rdparty/doctest.h" + +TEST_CASE("util extract_ip") + { + CHECK(extract_ip("[1.2.3.4]") == "1.2.3.4"); + CHECK(extract_ip("0x1.2.3.4") == "1.2.3.4"); + CHECK(extract_ip("[]") == ""); + } + /** * Return IP address without enclosing brackets and any leading 0x. Also * trims leading/trailing whitespace. @@ -74,6 +83,25 @@ std::string extract_ip(const std::string& i) return s; } +TEST_CASE("util extract_ip_and_len") + { + int len; + std::string out = extract_ip_and_len("[1.2.3.4/24]", &len); + CHECK(out == "1.2.3.4"); + CHECK(len == 24); + + out = extract_ip_and_len("0x1.2.3.4/32", &len); + CHECK(out == "1.2.3.4"); + CHECK(len == 32); + + out = extract_ip_and_len("[]/abcd", &len); + CHECK(out == ""); + CHECK(len == 0); + + out = extract_ip_and_len("[]/16", nullptr); + CHECK(out == ""); + } + /** * Given a subnet string, return IP address and subnet length separately. */ @@ -89,6 +117,12 @@ std::string extract_ip_and_len(const std::string& i, int* len) return extract_ip(i.substr(0, pos)); } +TEST_CASE("util get_unescaped_string") + { + CHECK(get_unescaped_string("abcde") == "abcde"); + CHECK(get_unescaped_string("\\x41BCD\\x45") == "ABCDE"); + } + /** * Takes a string, unescapes all characters that are escaped as hex codes * (\x##) and turns them into the equivalent ascii-codes. 
Returns a string @@ -127,6 +161,31 @@ std::string get_unescaped_string(const std::string& arg_str) return outstring; } +TEST_CASE("util get_escaped_string") + { + SUBCASE("returned ODesc") + { + ODesc* d = get_escaped_string(nullptr, "a bcd\n", 6, false); + CHECK(strcmp(d->Description(), "a\\x20bcd\\x0a") == 0); + } + + SUBCASE("provided ODesc") + { + ODesc d2; + get_escaped_string(&d2, "ab\\e", 4, true); + CHECK(strcmp(d2.Description(), "\\x61\\x62\\\\\\x65") == 0); + } + + SUBCASE("std::string versions") + { + std::string s = get_escaped_string("a b c", 5, false); + CHECK(s == "a\\x20b\\x20c"); + + s = get_escaped_string("d e", false); + CHECK(s == "d\\x20e"); + } + } + /** * Takes a string, escapes characters into equivalent hex codes (\x##), and * returns a string containing all escaped values. @@ -184,6 +243,12 @@ char* copy_string(const char* s) return c; } +TEST_CASE("util streq") + { + CHECK(streq("abcd", "abcd") == true); + CHECK(streq("abcd", "efgh") == false); + } + int streq(const char* s1, const char* s2) { return ! 
strcmp(s1, s2); @@ -287,6 +352,30 @@ char* skip_digits(char* s) return s; } +TEST_CASE("util get_word") + { + char orig[10]; + strcpy(orig, "two words"); + + SUBCASE("get first word") + { + char* a = (char*)orig; + char* b = get_word(a); + + CHECK(strcmp(a, "words") == 0); + CHECK(strcmp(b, "two") == 0); + } + + SUBCASE("get length of first word") + { + int len = strlen(orig); + int len2; + const char* b = nullptr; + get_word(len, orig, len2, b); + CHECK(len2 == 3); + } + } + char* get_word(char*& s) { char* w = s; @@ -316,6 +405,17 @@ void get_word(int length, const char* s, int& pwlen, const char*& pw) pwlen = len; } +TEST_CASE("util to_upper") + { + char a[10]; + strcpy(a, "aBcD"); + to_upper(a); + CHECK(strcmp(a, "ABCD") == 0); + + std::string b = "aBcD"; + CHECK(to_upper(b) == "ABCD"); + } + void to_upper(char* s) { while ( *s ) @@ -363,6 +463,16 @@ unsigned char encode_hex(int h) return hex[h]; } +TEST_CASE("util strpbrk_n") + { + const char* s = "abcdef"; + const char* o = strpbrk_n(5, s, "gc"); + CHECK(strcmp(o, "cdef") == 0); + + const char* f = strpbrk_n(5, s, "xyz"); + CHECK(f == nullptr); + } + // Same as strpbrk except that s is not NUL-terminated, but limited by // len. Note that '\0' is always implicitly contained in charset. const char* strpbrk_n(size_t len, const char* s, const char* charset) @@ -375,6 +485,20 @@ const char* strpbrk_n(size_t len, const char* s, const char* charset) } #ifndef HAVE_STRCASESTR + +TEST_CASE("util strcasestr") + { + const char* s = "this is a string"; + const char* out = strcasestr(s, "is"); + CHECK(strcmp(out, "is is a string") == 0); // first match is the "is" inside "this" + + const char* out2 = strcasestr(s, "IS"); + CHECK(strcmp(out2, "is is a string") == 0); // case-insensitive: same position + + const char* out3 = strcasestr(s, "not there"); + CHECK(out3 == nullptr); // a miss yields a null pointer + } + // This code is derived from software contributed to BSD by Chris Torek.
char* strcasestr(const char* s, const char* find) { @@ -401,6 +525,22 @@ char* strcasestr(const char* s, const char* find) } #endif +TEST_CASE("util atoi_n") + { + const char* dec = "12345"; + int val; + + CHECK(atoi_n(strlen(dec), dec, nullptr, 10, val) == 1); + CHECK(val == 12345); + + const char* hex = "12AB"; + CHECK(atoi_n(strlen(hex), hex, nullptr, 16, val) == 1); + CHECK(val == 0x12AB); + + const char* fail = "XYZ"; + CHECK(atoi_n(strlen(fail), fail, nullptr, 10, val) == 0); + } + template int atoi_n(int len, const char* s, const char** end, int base, T& result) { T n = 0; @@ -453,6 +593,15 @@ template int atoi_n(int len, const char* s, const char** end, int base template int atoi_n(int len, const char* s, const char** end, int base, int64_t& result); template int atoi_n(int len, const char* s, const char** end, int base, uint64_t& result); +TEST_CASE("util uitoa_n") + { + int val = 12345; + char str[20]; + const char* result = uitoa_n(val, str, 20, 10, "pref: "); + // TODO: i'm not sure this is the correct output. was it supposed to reverse the digits? 
+ CHECK(strcmp(str, "pref: 54321") == 0); + } + char* uitoa_n(uint64_t value, char* str, int n, int base, const char* prefix) { static char dig[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; @@ -486,6 +635,21 @@ char* uitoa_n(uint64_t value, char* str, int n, int base, const char* prefix) return str; } +TEST_CASE("util strstr_n") + { + const u_char* s = reinterpret_cast("this is a string"); + int out = strstr_n(16, s, 3, reinterpret_cast("str")); + CHECK(out == 10); + + out = strstr_n(16, s, 17, reinterpret_cast("is")); + CHECK(out == -1); + + out = strstr_n(16, s, 2, reinterpret_cast("IS")); + CHECK(out == -1); + + out = strstr_n(16, s, 9, reinterpret_cast("not there")); + CHECK(out == -1); + } int strstr_n(const int big_len, const u_char* big, const int little_len, const u_char* little) @@ -510,6 +674,12 @@ int fputs(int len, const char* s, FILE* fp) return 0; } +TEST_CASE("util is_printable") + { + CHECK(is_printable("abcd", 4) == true); + CHECK(is_printable("ab\0d", 4) == false); + } + bool is_printable(const char* s, int len) { while ( --len >= 0 ) @@ -518,6 +688,15 @@ bool is_printable(const char* s, int len) return true; } +TEST_CASE("util strtolower") + { + const char* a = "aBcD"; + CHECK(strtolower(a) == "abcd"); + + std::string b = "aBcD"; + CHECK(strtolower(b) == "abcd"); + } + std::string strtolower(const std::string& s) { std::string t = s; @@ -525,6 +704,20 @@ std::string strtolower(const std::string& s) return t; } +TEST_CASE("util fmt_bytes") + { + const char* a = "abcd"; + const char* af = fmt_bytes(a, 4); + CHECK(strcmp(a, af) == 0); + + const char* b = "abc\0abc"; + const char* bf = fmt_bytes(b, 7); + CHECK(strcmp(bf, "abc\\x00abc") == 0); + + const char* cf = fmt_bytes(a, 3); + CHECK(strcmp(cf, "abc") == 0); + } + const char* fmt_bytes(const char* data, int len) { static char buf[1024]; @@ -557,14 +750,14 @@ const char* fmt(const char* format, va_list al) va_list alc; va_copy(alc, al); - int n = safe_vsnprintf(buf, 
buf_len, format, al); + int n = vsnprintf(buf, buf_len, format, al); if ( (unsigned int) n >= buf_len ) { // Not enough room, grow the buffer. buf_len = n + 32; buf = (char*) safe_realloc(buf, buf_len); - n = safe_vsnprintf(buf, buf_len, format, alc); + n = vsnprintf(buf, buf_len, format, alc); if ( (unsigned int) n >= buf_len ) reporter->InternalError("confusion reformatting in fmt()"); @@ -682,6 +875,13 @@ bool is_file(const std::string& path) return S_ISREG(st.st_mode); } +TEST_CASE("util strreplace") + { + string s = "this is not a string"; + CHECK(strreplace(s, "not", "really") == "this is really a string"); + CHECK(strreplace(s, "not ", "") == "this is a string"); + } + string strreplace(const string& s, const string& o, const string& n) { string r = s; @@ -699,6 +899,18 @@ string strreplace(const string& s, const string& o, const string& n) return r; } +TEST_CASE("util strstrip") + { + string s = " abcd"; + CHECK(strstrip(s) == "abcd"); + + s = "abcd "; + CHECK(strstrip(s) == "abcd"); + + s = " abcd "; + CHECK(strstrip(s) == "abcd"); + } + std::string strstrip(std::string s) { auto notspace = [](unsigned char c) { return ! std::isspace(c); }; @@ -1013,6 +1225,12 @@ string bro_prefixes() return rval; } +TEST_CASE("util is_package_loader") + { + CHECK(is_package_loader("/some/path/__load__.zeek") == true); + CHECK(is_package_loader("/some/path/notload.zeek") == false); + } + const array script_extensions = {".zeek", ".bro"}; bool is_package_loader(const string& path) @@ -1072,6 +1290,32 @@ FILE* open_package(string& path, const string& mode) return 0; } +TEST_CASE("util path ops") + { + SUBCASE("SafeDirname") + { + SafeDirname d("/this/is/a/path", false); + CHECK(d.result == "/this/is/a"); + + SafeDirname d2("invalid", false); + CHECK(d2.result == "."); + + SafeDirname d3("./filename", false); + CHECK(d3.result == "."); // check d3 itself, not d2 again + } + + SUBCASE("SafeBasename") + { + SafeBasename b("/this/is/a/path", false); + CHECK(b.result == "path"); + CHECK(!
b.error); + + SafeBasename b2("justafile", false); + CHECK(b2.result == "justafile"); + CHECK(! b2.error); + } + } + void SafePathOp::CheckValid(const char* op_result, const char* path, bool error_aborts) { @@ -1128,6 +1372,16 @@ void SafeBasename::DoFunc(const string& path, bool error_aborts) delete [] tmp; } +TEST_CASE("util implode_string_vector") + { + std::vector v = { "a", "b", "c" }; + CHECK(implode_string_vector(v, ",") == "a,b,c"); + CHECK(implode_string_vector(v, "") == "abc"); + + v.clear(); + CHECK(implode_string_vector(v, ",") == ""); + } + string implode_string_vector(const std::vector& v, const std::string& delim) { @@ -1144,6 +1398,13 @@ string implode_string_vector(const std::vector& v, return rval; } +TEST_CASE("util flatten_script_name") + { + CHECK(flatten_script_name("script", "some/path") == "some.path.script"); + CHECK(flatten_script_name("other/path/__load__.zeek", "some/path") == "some.path.other.path"); + CHECK(flatten_script_name("path/to/script", "") == "path.to.script"); + } + string flatten_script_name(const string& name, const string& prefix) { string rval = prefix; @@ -1164,6 +1425,23 @@ string flatten_script_name(const string& name, const string& prefix) return rval; } +TEST_CASE("util tokenize_string") + { + auto v = tokenize_string("/this/is/a/path", "/", nullptr); + CHECK(v->size() == 5); + CHECK(*v == vector({ "", "this", "is", "a", "path" })); + delete v; + + std::vector v2; + tokenize_string("/this/is/path/2", "/", &v2); + CHECK(v2.size() == 5); + CHECK(v2 == vector({ "", "this", "is", "path", "2" })); + + v2.clear(); + tokenize_string("/wrong/delim", ",", &v2); + CHECK(v2.size() == 1); + } + vector* tokenize_string(string input, const string& delim, vector* rval) { @@ -1182,6 +1460,13 @@ vector* tokenize_string(string input, const string& delim, return rval; } +TEST_CASE("util normalize_path") + { + CHECK(normalize_path("/1/2/3") == "/1/2/3"); + CHECK(normalize_path("/1/./2/3") == "/1/2/3"); + 
CHECK(normalize_path("/1/2/../3") == "/1/3"); + CHECK(normalize_path("1/2/3/") == "1/2/3"); + } string normalize_path(const string& path) { @@ -1311,6 +1596,13 @@ static bool ends_with(const std::string& s, const std::string& ending) return std::equal(ending.rbegin(), ending.rend(), s.rbegin()); } +TEST_CASE("util ends_with") + { + CHECK(ends_with("abcde", "de") == true); + CHECK(ends_with("abcde", "fg") == false); + CHECK(ends_with("abcde", "abcedf") == false); + } + string find_script_file(const string& filename, const string& path_set) { vector paths; @@ -1344,7 +1636,7 @@ FILE* rotate_file(const char* name, RecordVal* rotate_info) char newname[buflen], tmpname[buflen+4]; - safe_snprintf(newname, buflen, "%s.%d.%.06f.tmp", + snprintf(newname, buflen, "%s.%d.%.06f.tmp", name, getpid(), network_time); newname[buflen-1] = '\0'; strcpy(tmpname, newname); @@ -1810,6 +2102,11 @@ void operator delete[](void* v) #endif +TEST_CASE("util canonify_name") + { + CHECK(canonify_name("file name") == "FILE_NAME"); + } + std::string canonify_name(const std::string& name) { unsigned int len = name.size(); @@ -1887,6 +2184,13 @@ static string json_escape_byte(char c) return result; } +TEST_CASE("util json_escape_utf8") + { + CHECK(json_escape_utf8("string") == "string"); + CHECK(json_escape_utf8("string\n") == "string\n"); + CHECK(json_escape_utf8("string\x82") == "string\\x82"); + } + string json_escape_utf8(const string& val) { string result; diff --git a/src/util.h b/src/util.h index 129ca02276..3665518f96 100644 --- a/src/util.h +++ b/src/util.h @@ -510,6 +510,7 @@ inline char* safe_strncpy(char* dest, const char* src, size_t n) return result; } +ZEEK_DEPRECATED("Remove in v4.1: Use system snprintf instead") inline int safe_snprintf(char* str, size_t size, const char* format, ...) { va_list al; @@ -521,6 +522,7 @@ inline int safe_snprintf(char* str, size_t size, const char* format, ...) 
return result; } +ZEEK_DEPRECATED("Remove in v4.1: Use system vsnprintf instead") inline int safe_vsnprintf(char* str, size_t size, const char* format, va_list al) { int result = vsnprintf(str, size, format, al);