diff --git a/src/module_util.cc b/src/module_util.cc index 500d7f4d08..db6cc75012 100644 --- a/src/module_util.cc +++ b/src/module_util.cc @@ -5,6 +5,8 @@ #include #include "module_util.h" +#include "3rdparty/doctest.h" + using namespace std; static int streq(const char* s1, const char* s2) @@ -12,7 +14,20 @@ static int streq(const char* s1, const char* s2) return ! strcmp(s1, s2); } -// Returns it without trailing "::". +TEST_CASE("module_util streq") + { + CHECK(streq("abcd", "abcd") == true); + CHECK(streq("abcd", "efgh") == false); + } + +TEST_CASE("module_util extract_module_name") + { + CHECK(extract_module_name("mod") == GLOBAL_MODULE_NAME); + CHECK(extract_module_name("mod::") == "mod"); + CHECK(extract_module_name("mod::var") == "mod"); + } + +// Returns it without trailing "::" var section. string extract_module_name(const char* name) { string module_name = name; @@ -26,6 +41,14 @@ string extract_module_name(const char* name) return module_name; } +TEST_CASE("module_util extract_var_name") + { + CHECK(extract_var_name("mod") == "mod"); + CHECK(extract_var_name("mod::") == ""); + CHECK(extract_var_name("mod::var") == "var"); + CHECK(extract_var_name("::var") == "var"); + } + string extract_var_name(const char *name) { string var_name = name; @@ -40,6 +63,13 @@ string extract_var_name(const char *name) return var_name.substr(pos+2); } +TEST_CASE("module_util normalized_module_name") + { + CHECK(normalized_module_name("a") == "a"); + CHECK(normalized_module_name("module") == "module"); + CHECK(normalized_module_name("module::") == "module"); + } + string normalized_module_name(const char* module_name) { int mod_len; @@ -50,6 +80,18 @@ string normalized_module_name(const char* module_name) return string(module_name, mod_len); } +TEST_CASE("module_util make_full_var_name") + { + CHECK(make_full_var_name(nullptr, "GLOBAL::var") == "var"); + CHECK(make_full_var_name(GLOBAL_MODULE_NAME, "var") == "var"); + CHECK(make_full_var_name(nullptr, "notglobal::var") == "notglobal::var"); + CHECK(make_full_var_name(nullptr, "::var") == "::var"); + + CHECK(make_full_var_name("module", "var") == "module::var"); + CHECK(make_full_var_name("module::", "var") == "module::var"); + CHECK(make_full_var_name("", "var") == "::var"); + } + string make_full_var_name(const char* module_name, const char* var_name) { if ( ! module_name || streq(module_name, GLOBAL_MODULE_NAME) || diff --git a/src/util.cc b/src/util.cc index 461835964e..b4b5969c0f 100644 --- a/src/util.cc +++ b/src/util.cc @@ -53,6 +53,15 @@ #include "iosource/Manager.h" #include "ConvertUTF.h" +#include "3rdparty/doctest.h" + +TEST_CASE("util extract_ip") + { + CHECK(extract_ip("[1.2.3.4]") == "1.2.3.4"); + CHECK(extract_ip("0x1.2.3.4") == "1.2.3.4"); + CHECK(extract_ip("[]") == ""); + } + /** * Return IP address without enclosing brackets and any leading 0x. Also * trims leading/trailing whitespace. @@ -74,6 +83,27 @@ std::string extract_ip(const std::string& i) return s; } +TEST_CASE("util extract_ip_and_len") + { + int len; + std::string out = extract_ip_and_len("[1.2.3.4/24]", &len); + CHECK(out == "1.2.3.4"); + CHECK(len == 24); + + out = extract_ip_and_len("0x1.2.3.4/32", &len); + CHECK(out == "1.2.3.4"); + CHECK(len == 32); + + out = extract_ip_and_len("[]/abcd", &len); + CHECK(out == ""); + CHECK(len == 0); + + len = 12345; + out = extract_ip_and_len("[]/16", nullptr); + CHECK(out == ""); + CHECK(len == 12345); + } + /** * Given a subnet string, return IP address and subnet length separately. */ @@ -89,6 +119,12 @@ std::string extract_ip_and_len(const std::string& i, int* len) return extract_ip(i.substr(0, pos)); } +TEST_CASE("util get_unescaped_string") + { + CHECK(get_unescaped_string("abcde") == "abcde"); + CHECK(get_unescaped_string("\\x41BCD\\x45") == "ABCDE"); + } + /** * Takes a string, unescapes all characters that are escaped as hex codes * (\x##) and turns them into the equivalent ascii-codes. Returns a string @@ -127,6 +163,31 @@ std::string get_unescaped_string(const std::string& arg_str) return outstring; } +TEST_CASE("util get_escaped_string") + { + SUBCASE("returned ODesc") + { + ODesc* d = get_escaped_string(nullptr, "a bcd\n", 6, false); + CHECK(strcmp(d->Description(), "a\\x20bcd\\x0a") == 0); + } + + SUBCASE("provided ODesc") + { + ODesc d2; + get_escaped_string(&d2, "ab\\e", 4, true); + CHECK(strcmp(d2.Description(), "\\x61\\x62\\\\\\x65") == 0); + } + + SUBCASE("std::string versions") + { + std::string s = get_escaped_string("a b c", 5, false); + CHECK(s == "a\\x20b\\x20c"); + + s = get_escaped_string("d e", false); + CHECK(s == "d\\x20e"); + } + } + /** * Takes a string, escapes characters into equivalent hex codes (\x##), and * returns a string containing all escaped values. @@ -184,6 +245,12 @@ char* copy_string(const char* s) return c; } +TEST_CASE("util streq") + { + CHECK(streq("abcd", "abcd") == true); + CHECK(streq("abcd", "efgh") == false); + } + int streq(const char* s1, const char* s2) { return ! strcmp(s1, s2); @@ -287,6 +354,30 @@ char* skip_digits(char* s) return s; } +TEST_CASE("util get_word") + { + char orig[10]; + strcpy(orig, "two words"); + + SUBCASE("get first word") + { + char* a = (char*)orig; + char* b = get_word(a); + + CHECK(strcmp(a, "words") == 0); + CHECK(strcmp(b, "two") == 0); + } + + SUBCASE("get length of first word") + { + int len = strlen(orig); + int len2; + const char* b = nullptr; + get_word(len, orig, len2, b); + CHECK(len2 == 3); + } + } + char* get_word(char*& s) { char* w = s; @@ -316,6 +407,17 @@ void get_word(int length, const char* s, int& pwlen, const char*& pw) pwlen = len; } +TEST_CASE("util to_upper") + { + char a[10]; + strcpy(a, "aBcD"); + to_upper(a); + CHECK(strcmp(a, "ABCD") == 0); + + std::string b = "aBcD"; + CHECK(to_upper(b) == "ABCD"); + } + void to_upper(char* s) { while ( *s ) @@ -363,6 +465,16 @@ unsigned char encode_hex(int h) return hex[h]; } +TEST_CASE("util strpbrk_n") + { + const char* s = "abcdef"; + const char* o = strpbrk_n(5, s, "gc"); + CHECK(strcmp(o, "cdef") == 0); + + const char* f = strpbrk_n(5, s, "xyz"); + CHECK(f == nullptr); + } + // Same as strpbrk except that s is not NUL-terminated, but limited by // len. Note that '\0' is always implicitly contained in charset. const char* strpbrk_n(size_t len, const char* s, const char* charset) @@ -375,6 +487,20 @@ const char* strpbrk_n(size_t len, const char* s, const char* charset) } #ifndef HAVE_STRCASESTR + +TEST_CASE("util strcasestr") + { + const char* s = "this is a string"; + const char* out = strcasestr(s, "is"); + CHECK(strcmp(out, "is a string") == 0); + + const char* out2 = strcasestr(s, "IS"); + CHECK(strcmp(out2, "is a string") == 0); + + const char* out3 = strcasestr(s, "not there"); + CHECK(strcmp(out2, s) == 0); + } + // This code is derived from software contributed to BSD by Chris Torek. char* strcasestr(const char* s, const char* find) { @@ -401,6 +527,22 @@ char* strcasestr(const char* s, const char* find) } #endif +TEST_CASE("util atoi_n") + { + const char* dec = "12345"; + int val; + + CHECK(atoi_n(strlen(dec), dec, nullptr, 10, val) == 1); + CHECK(val == 12345); + + const char* hex = "12AB"; + CHECK(atoi_n(strlen(hex), hex, nullptr, 16, val) == 1); + CHECK(val == 0x12AB); + + const char* fail = "XYZ"; + CHECK(atoi_n(strlen(fail), fail, nullptr, 10, val) == 0); + } + template int atoi_n(int len, const char* s, const char** end, int base, T& result) { T n = 0; @@ -453,6 +595,15 @@ template int atoi_n(int len, const char* s, const char** end, int base template int atoi_n(int len, const char* s, const char** end, int base, int64_t& result); template int atoi_n(int len, const char* s, const char** end, int base, uint64_t& result); +TEST_CASE("util uitoa_n") + { + int val = 12345; + char str[20]; + const char* result = uitoa_n(val, str, 20, 10, "pref: "); + // TODO: i'm not sure this is the correct output. was it supposed to reverse the digits? + CHECK(strcmp(str, "pref: 54321") == 0); + } + char* uitoa_n(uint64_t value, char* str, int n, int base, const char* prefix) { static char dig[] = "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; @@ -486,6 +637,21 @@ char* uitoa_n(uint64_t value, char* str, int n, int base, const char* prefix) return str; } +TEST_CASE("util strstr_n") + { + const u_char* s = reinterpret_cast("this is a string"); + int out = strstr_n(16, s, 3, reinterpret_cast("str")); + CHECK(out == 10); + + out = strstr_n(16, s, 17, reinterpret_cast("is")); + CHECK(out == -1); + + out = strstr_n(16, s, 2, reinterpret_cast("IS")); + CHECK(out == -1); + + out = strstr_n(16, s, 9, reinterpret_cast("not there")); + CHECK(out == -1); + } int strstr_n(const int big_len, const u_char* big, const int little_len, const u_char* little) @@ -510,6 +676,12 @@ int fputs(int len, const char* s, FILE* fp) return 0; } +TEST_CASE("util is_printable") + { + CHECK(is_printable("abcd", 4) == true); + CHECK(is_printable("ab\0d", 4) == false); + } + bool is_printable(const char* s, int len) { while ( --len >= 0 ) @@ -518,6 +690,15 @@ bool is_printable(const char* s, int len) return true; } +TEST_CASE("util strtolower") + { + const char* a = "aBcD"; + CHECK(strtolower(a) == "abcd"); + + std::string b = "aBcD"; + CHECK(strtolower(b) == "abcd"); + } + std::string strtolower(const std::string& s) { std::string t = s; @@ -525,6 +706,20 @@ std::string strtolower(const std::string& s) return t; } +TEST_CASE("util fmt_bytes") + { + const char* a = "abcd"; + const char* af = fmt_bytes(a, 4); + CHECK(strcmp(a, af) == 0); + + const char* b = "abc\0abc"; + const char* bf = fmt_bytes(b, 7); + CHECK(strcmp(bf, "abc\\x00abc") == 0); + + const char* cf = fmt_bytes(a, 3); + CHECK(strcmp(cf, "abc") == 0); + } + const char* fmt_bytes(const char* data, int len) { static char buf[1024]; @@ -682,6 +877,13 @@ bool is_file(const std::string& path) return S_ISREG(st.st_mode); } +TEST_CASE("util strreplace") + { + string s = "this is not a string"; + CHECK(strreplace(s, "not", "really") == "this is really a string"); + CHECK(strreplace(s, "not ", "") == "this is a string"); + } + string strreplace(const string& s, const string& o, const string& n) { string r = s; @@ -699,6 +901,18 @@ string strreplace(const string& s, const string& o, const string& n) return r; } +TEST_CASE("util strstrip") + { + string s = " abcd"; + CHECK(strstrip(s) == "abcd"); + + s = "abcd "; + CHECK(strstrip(s) == "abcd"); + + s = " abcd "; + CHECK(strstrip(s) == "abcd"); + } + std::string strstrip(std::string s) { auto notspace = [](unsigned char c) { return ! std::isspace(c); }; @@ -1013,6 +1227,12 @@ string bro_prefixes() return rval; } +TEST_CASE("util is_package_loader") + { + CHECK(is_package_loader("/some/path/__load__.zeek") == true); + CHECK(is_package_loader("/some/path/notload.zeek") == false); + } + const array script_extensions = {".zeek", ".bro"}; bool is_package_loader(const string& path) @@ -1072,6 +1292,32 @@ FILE* open_package(string& path, const string& mode) return 0; } +TEST_CASE("util path ops") + { + SUBCASE("SafeDirname") + { + SafeDirname d("/this/is/a/path", false); + CHECK(d.result == "/this/is/a"); + + SafeDirname d2("invalid", false); + CHECK(d2.result == "."); + + SafeDirname d3("./filename", false); + CHECK(d2.result == "."); + } + + SUBCASE("SafeBasename") + { + SafeBasename b("/this/is/a/path", false); + CHECK(b.result == "path"); + CHECK(! b.error); + + SafeBasename b2("justafile", false); + CHECK(b2.result == "justafile"); + CHECK(! b2.error); + } + } + void SafePathOp::CheckValid(const char* op_result, const char* path, bool error_aborts) { @@ -1128,6 +1374,16 @@ void SafeBasename::DoFunc(const string& path, bool error_aborts) delete [] tmp; } +TEST_CASE("util implode_string_vector") + { + std::vector v = { "a", "b", "c" }; + CHECK(implode_string_vector(v, ",") == "a,b,c"); + CHECK(implode_string_vector(v, "") == "abc"); + + v.clear(); + CHECK(implode_string_vector(v, ",") == ""); + } + string implode_string_vector(const std::vector& v, const std::string& delim) { @@ -1144,6 +1400,13 @@ string implode_string_vector(const std::vector& v, return rval; } +TEST_CASE("util flatten_script_name") + { + CHECK(flatten_script_name("script", "some/path") == "some.path.script"); + CHECK(flatten_script_name("other/path/__load__.zeek", "some/path") == "some.path.other.path"); + CHECK(flatten_script_name("path/to/script", "") == "path.to.script"); + } + string flatten_script_name(const string& name, const string& prefix) { string rval = prefix; @@ -1164,6 +1427,23 @@ string flatten_script_name(const string& name, const string& prefix) return rval; } +TEST_CASE("tuil tokenize_string") + { + auto v = tokenize_string("/this/is/a/path", "/", nullptr); + CHECK(v->size() == 5); + CHECK(*v == vector({ "", "this", "is", "a", "path" })); + delete v; + + std::vector v2; + tokenize_string("/this/is/path/2", "/", &v2); + CHECK(v2.size() == 5); + CHECK(v2 == vector({ "", "this", "is", "path", "2" })); + + v2.clear(); + tokenize_string("/wrong/delim", ",", &v2); + CHECK(v2.size() == 1); + } + vector* tokenize_string(string input, const string& delim, vector* rval) { @@ -1182,6 +1462,13 @@ vector* tokenize_string(string input, const string& delim, return rval; } +TEST_CASE("util normalize_path") + { + CHECK(normalize_path("/1/2/3") == "/1/2/3"); + CHECK(normalize_path("/1/./2/3") == "/1/2/3"); + CHECK(normalize_path("/1/2/../3") == "/1/3"); + CHECK(normalize_path("1/2/3/") == "1/2/3"); + } string normalize_path(const string& path) { @@ -1311,6 +1598,13 @@ static bool ends_with(const std::string& s, const std::string& ending) return std::equal(ending.rbegin(), ending.rend(), s.rbegin()); } +TEST_CASE("util ends_with") + { + CHECK(ends_with("abcde", "de") == true); + CHECK(ends_with("abcde", "fg") == false); + CHECK(ends_with("abcde", "abcedf") == false); + } + string find_script_file(const string& filename, const string& path_set) { vector paths; @@ -1810,6 +2104,11 @@ void operator delete[](void* v) #endif +TEST_CASE("util canonify_name") + { + CHECK(canonify_name("file name") == "FILE_NAME"); + } + std::string canonify_name(const std::string& name) { unsigned int len = name.size(); @@ -1887,6 +2186,13 @@ static string json_escape_byte(char c) return result; } +TEST_CASE("util json_escape_utf8") + { + CHECK(json_escape_utf8("string") == "string"); + CHECK(json_escape_utf8("string\n") == "string\n"); + CHECK(json_escape_utf8("string\x82") == "string\\x82"); + } + string json_escape_utf8(const string& val) { string result;