protocol/ip: Only attach IP_Hdr to Packet if valid

Ensure packet->ip_hdr is not set (so no one can assume it's valid)
when AnalyzePacket() found something weird with the header.
This commit is contained in:
Arne Welzel 2022-09-22 16:29:37 +02:00 committed by Tim Wojtulewicz
parent be5a30df7d
commit 70c74e9d71

View file

@ -47,11 +47,11 @@ bool IPAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* packet)
// data about the header. // data about the header.
auto ip = (const struct ip*)data; auto ip = (const struct ip*)data;
uint32_t protocol = ip->ip_v; uint32_t protocol = ip->ip_v;
std::shared_ptr<IP_Hdr> ip_hdr;
// This is a unique pointer because of the mass of early returns from this method.
if ( protocol == 4 ) if ( protocol == 4 )
{ {
packet->ip_hdr = std::make_shared<IP_Hdr>(ip, false); ip_hdr = std::make_shared<IP_Hdr>(ip, false);
packet->l3_proto = L3_IPV4; packet->l3_proto = L3_IPV4;
} }
else if ( protocol == 6 ) else if ( protocol == 6 )
@ -62,7 +62,7 @@ bool IPAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* packet)
return false; return false;
} }
packet->ip_hdr = std::make_shared<IP_Hdr>((const struct ip6_hdr*)data, false, len); ip_hdr = std::make_shared<IP_Hdr>((const struct ip6_hdr*)data, false, len);
packet->l3_proto = L3_IPV6; packet->l3_proto = L3_IPV6;
} }
else else
@ -71,20 +71,9 @@ bool IPAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* packet)
return false; return false;
} }
// If there's an encapsulation stack in this packet, meaning this packet is part of a chain
// of tunnels, make sure to store the IP header in the last flow in the stack so it can be
// used by previous analyzers as we return up the chain.
if ( packet->encap )
{
if ( auto* ec = packet->encap->Last() )
ec->ip_hdr = packet->ip_hdr;
}
const struct ip* ip4 = packet->ip_hdr->IP4_Hdr();
// TotalLen() returns the full length of the IP portion of the packet, including // TotalLen() returns the full length of the IP portion of the packet, including
// the IP header and payload. // the IP header and payload.
uint32_t total_len = packet->ip_hdr->TotalLen(); uint32_t total_len = ip_hdr->TotalLen();
if ( total_len == 0 ) if ( total_len == 0 )
{ {
// TCP segmentation offloading can zero out the ip_len field. // TCP segmentation offloading can zero out the ip_len field.
@ -107,7 +96,7 @@ bool IPAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* packet)
// For both of these it is safe to pass ip_hdr because the presence // For both of these it is safe to pass ip_hdr because the presence
// is guaranteed for the functions that pass data to us. // is guaranteed for the functions that pass data to us.
uint16_t ip_hdr_len = packet->ip_hdr->HdrLen(); uint16_t ip_hdr_len = ip_hdr->HdrLen();
if ( ip_hdr_len > total_len ) if ( ip_hdr_len > total_len )
{ {
Weird("invalid_IP_header_size", packet); Weird("invalid_IP_header_size", packet);
@ -120,7 +109,9 @@ bool IPAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* packet)
return false; return false;
} }
if ( packet->ip_hdr->IP4_Hdr() ) const struct ip* ip4 = ip_hdr->IP4_Hdr();
if ( ip4 )
{ {
if ( ip_hdr_len < sizeof(struct ip) ) if ( ip_hdr_len < sizeof(struct ip) )
{ {
@ -137,6 +128,18 @@ bool IPAnalyzer::AnalyzePacket(size_t len, const uint8_t* data, Packet* packet)
} }
} }
// If we got here, the IP_Hdr is most likely valid and safe to use.
packet->ip_hdr = ip_hdr;
// If there's an encapsulation stack in this packet, meaning this packet is part of a chain
// of tunnels, make sure to store the IP header in the last flow in the stack so it can be
// used by previous analyzers as we return up the chain.
if ( packet->encap )
{
if ( auto* ec = packet->encap->Last() )
ec->ip_hdr = packet->ip_hdr;
}
// Ignore if packet matches packet filter. // Ignore if packet matches packet filter.
detail::PacketFilter* packet_filter = packet_mgr->GetPacketFilter(false); detail::PacketFilter* packet_filter = packet_mgr->GetPacketFilter(false);
if ( packet_filter && packet_filter->Match(packet->ip_hdr, total_len, len) ) if ( packet_filter && packet_filter->Match(packet->ip_hdr, total_len, len) )