#include #include #include #include "icmp.h" #include "icmp_types.h" #include "internal/icmp_packet.h" /* Helper to build an error reply with embedded ICMP packet */ static icmp_reply_t make_error_reply(uint8_t error_type, uint8_t *buffer, size_t buffer_size) { icmp_reply_t reply; memset(&reply, 0, sizeof(reply)); memset(buffer, 0, buffer_size); reply.type = error_type; reply.code = 0; reply.payload = buffer; reply.payload_len = buffer_size; return reply; } /* Helper to fill IP header in buffer */ static void fill_ip_header(uint8_t *buffer, const char *src, const char *dst, uint8_t protocol) { struct ip_header *ip = (struct ip_header *)buffer; ip->version_ihl = 0x45; /* version 4, IHL 5 (20 bytes) */ ip->ttl = 64; ip->protocol = protocol; ip->saddr = inet_addr(src); ip->daddr = inet_addr(dst); } /* Helper to fill ICMP header after IP header */ static void fill_icmp_header(uint8_t *buffer, uint8_t type, uint8_t code, uint16_t id, uint16_t seq) { buffer[20] = type; /* After 20-byte IP header */ buffer[21] = code; buffer[22] = 0; /* checksum high */ buffer[23] = 0; /* checksum low */ /* id and seq in network byte order */ *(uint16_t *)(buffer + 24) = htons(id); *(uint16_t *)(buffer + 26) = htons(seq); } Test(extract_offending, time_exceeded_with_echo) { uint8_t buffer[28]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_TIME_EXCEEDED, buffer, sizeof(buffer)); fill_ip_header(buffer, "192.168.1.100", "8.8.8.8", 1); fill_icmp_header(buffer, ICMP_TYPE_ECHO_REQUEST, 0, 0x1234, 0x5678); icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0, "Should return 0 for valid Time Exceeded"); cr_assert_eq(offending.src.s_addr, inet_addr("192.168.1.100"), "Source address should match"); cr_assert_eq(offending.dst.s_addr, inet_addr("8.8.8.8"), "Destination address should match"); cr_assert_eq(offending.protocol, 1, "Protocol should be 1 (ICMP)"); cr_assert_eq(offending.icmp_type, ICMP_TYPE_ECHO_REQUEST, "ICMP type should be Echo Request"); cr_assert_eq(offending.icmp_code, 0, "ICMP code should be 0"); cr_assert_eq(offending.rest.echo.id, 0x1234, "Echo ID should match"); cr_assert_eq(offending.rest.echo.seq, 0x5678, "Echo seq should match"); } Test(extract_offending, dest_unreachable_with_timestamp) { uint8_t buffer[28]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_DEST_UNREACHABLE, buffer, sizeof(buffer)); fill_ip_header(buffer, "10.0.0.1", "10.0.0.2", 1); fill_icmp_header(buffer, ICMP_TYPE_TIMESTAMP_REQUEST, 0, 0xABCD, 0xEF01); icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0, "Should return 0 for valid Dest Unreachable"); cr_assert_eq(offending.protocol, 1, "Protocol should be ICMP"); cr_assert_eq(offending.icmp_type, ICMP_TYPE_TIMESTAMP_REQUEST); cr_assert_eq(offending.rest.echo.id, 0xABCD); cr_assert_eq(offending.rest.echo.seq, 0xEF01); } Test(extract_offending, redirect_with_gateway) { uint8_t buffer[28]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_REDIRECT, buffer, sizeof(buffer)); fill_ip_header(buffer, "10.0.0.1", "10.0.0.2", 1); /* For redirect, the 4 bytes after type/code/checksum are the gateway address */ buffer[20] = ICMP_TYPE_REDIRECT; buffer[21] = 0; buffer[22] = 0; buffer[23] = 0; *(uint32_t *)(buffer + 24) = inet_addr("192.168.1.1"); /* Gateway */ icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0, "Should return 0 for valid Redirect"); cr_assert_eq(offending.icmp_type, ICMP_TYPE_REDIRECT); cr_assert_eq(offending.rest.gateway, inet_addr("192.168.1.1"), "Gateway address should match"); } Test(extract_offending, parameter_problem) { uint8_t buffer[28]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_PARAMETER_PROBLEM, buffer, sizeof(buffer)); fill_ip_header(buffer, "1.2.3.4", "5.6.7.8", 1); fill_icmp_header(buffer, ICMP_TYPE_ECHO_REQUEST, 0, 100, 200); icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0, "Should return 0 for Parameter Problem"); cr_assert_eq(offending.rest.echo.id, 100); cr_assert_eq(offending.rest.echo.seq, 200); } Test(extract_offending, non_icmp_protocol) { uint8_t buffer[28]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_TIME_EXCEEDED, buffer, sizeof(buffer)); /* Embedded packet is UDP (protocol 17), not ICMP */ fill_ip_header(buffer, "10.0.0.1", "10.0.0.2", 17); memset(buffer + 20, 0, 8); /* 8 bytes of UDP header */ icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0, "Should return 0 even for non-ICMP"); cr_assert_eq(offending.protocol, 17, "Protocol should be 17 (UDP)"); /* icmp_type and icmp_code are set but not meaningful for UDP */ } Test(extract_offending, non_error_type_returns_error) { uint8_t buffer[28]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_ECHO_REPLY, buffer, sizeof(buffer)); fill_ip_header(buffer, "10.0.0.1", "10.0.0.2", 1); icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, -1, "Should return -1 for non-error type (Echo Reply)"); } Test(extract_offending, null_reply_returns_error) { icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(NULL, &offending); cr_assert_eq(ret, -1, "Should return -1 for NULL reply"); } Test(extract_offending, null_offending_returns_error) { uint8_t buffer[28]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_TIME_EXCEEDED, buffer, sizeof(buffer)); int ret = icmp_error_extract_offending(&reply, NULL); cr_assert_eq(ret, -1, "Should return -1 for NULL offending"); } Test(extract_offending, payload_too_short_returns_error) { uint8_t buffer[27]; /* One byte short */ icmp_reply_t reply = make_error_reply(ICMP_TYPE_TIME_EXCEEDED, buffer, sizeof(buffer)); icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, -1, "Should return -1 when payload < 28 bytes"); } Test(extract_offending, ip_header_with_options) { uint8_t buffer[40]; /* 20 byte IP + options + 8 byte ICMP */ icmp_reply_t reply = make_error_reply(ICMP_TYPE_TIME_EXCEEDED, buffer, sizeof(buffer)); /* IP header with IHL=6 (24 bytes including 4 bytes of options) */ struct ip_header *ip = (struct ip_header *)buffer; ip->version_ihl = 0x46; /* version 4, IHL 6 (24 bytes) */ ip->ttl = 64; ip->protocol = 1; ip->saddr = inet_addr("192.168.1.1"); ip->daddr = inet_addr("192.168.1.2"); /* ICMP header starts at offset 24 */ fill_icmp_header(buffer + 4, ICMP_TYPE_ECHO_REQUEST, 0, 0x9999, 0xAAAA); icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0, "Should handle IP headers with options"); cr_assert_eq(offending.protocol, 1); cr_assert_eq(offending.icmp_type, ICMP_TYPE_ECHO_REQUEST); cr_assert_eq(offending.rest.echo.id, 0x9999); cr_assert_eq(offending.rest.echo.seq, 0xAAAA); } Test(extract_offending, frag_needed_mtu_from_outer_icmp) { /* outer ICMP header (8 bytes): type=3, code=4, unused=0, mtu=1400 */ uint8_t outer[8]; /* embedded IP (20 bytes) + embedded ICMP echo request (8 bytes) */ uint8_t embedded[28]; icmp_reply_t reply; memset(outer, 0, sizeof(outer)); outer[0] = ICMP_TYPE_DEST_UNREACHABLE; outer[1] = ICMP_CODE_FRAG_NEEDED; *(uint16_t *)(outer + 6) = htons(1400); /* next-hop MTU */ fill_ip_header(embedded, "10.0.0.1", "10.0.0.2", 1); fill_icmp_header(embedded, ICMP_TYPE_ECHO_REQUEST, 0, 0x1234, 0x0001); memset(&reply, 0, sizeof(reply)); reply.type = ICMP_TYPE_DEST_UNREACHABLE; reply.code = ICMP_CODE_FRAG_NEEDED; reply.payload = embedded; reply.payload_len = sizeof(embedded); reply.ip_payload = outer; icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0); cr_assert_eq(offending.icmp_type, ICMP_TYPE_ECHO_REQUEST); cr_assert_eq(offending.rest.echo.id, 0x1234); cr_assert_eq(offending.rest.echo.seq, 0x0001); cr_assert_eq(offending.next_mtu, 1400, "MTU should be in host byte order"); } Test(extract_offending, redirect_with_tcp_packet) { uint8_t buffer[48]; icmp_reply_t reply = make_error_reply(ICMP_TYPE_REDIRECT, buffer, sizeof(buffer)); fill_ip_header(buffer, "10.0.0.1", "10.0.0.2", 6); // Protocol 6 = TCP /* TCP header after IP */ memset(buffer + 20, 0, 20); buffer[20] = 0x50; buffer[21] = 0x00; icmp_offending_packet_t offending; int ret = icmp_error_extract_offending(&reply, &offending); cr_assert_eq(ret, 0); cr_assert_eq(offending.protocol, 6, "Should be TCP"); }