diff --git a/src/ping/core/callback.c b/src/ping/core/callback.c index 3ce108f..a97c454 100644 --- a/src/ping/core/callback.c +++ b/src/ping/core/callback.c @@ -4,11 +4,11 @@ /* Forward declarations */ static int extract_our_echo(const icmp_reply_t *reply, uint16_t our_id, - uint16_t *seq_out); + struct in_addr dest, uint16_t *seq_out); static void handle_echo_reply(struct ping_state *state, const icmp_reply_t *reply); static int extract_our_error(const icmp_reply_t *reply, uint16_t our_id, - icmp_offending_packet_t *out, uint16_t *seq_out); + struct in_addr dest, icmp_offending_packet_t *out, uint16_t *seq_out); static void handle_icmp_error(struct ping_state *state, const icmp_reply_t *reply); /* -------------------- */ @@ -28,14 +28,14 @@ ping_callback(const icmp_reply_t *reply, void *userdata) static int extract_our_echo(const icmp_reply_t *reply, uint16_t our_id, - uint16_t *seq_out) + struct in_addr dest, uint16_t *seq_out) { uint16_t id; uint16_t seq; if (0 > icmp_reply_id_seq(reply, &id, &seq)) return 0; - if (our_id != id) + if (our_id != id || reply->from.s_addr != dest.s_addr) return 0; *seq_out = seq; return 1; @@ -47,7 +47,7 @@ handle_echo_reply(struct ping_state *state, const icmp_reply_t *reply) uint16_t seq; int64_t rtt; - if (0 == extract_our_echo(reply, state->id, &seq)) + if (0 == extract_our_echo(reply, state->id, state->dest, &seq)) return; rtt = ping_tracker_record_recv(state->tracker, seq, &reply->timestamp); if (0 > rtt) @@ -62,13 +62,11 @@ handle_echo_reply(struct ping_state *state, const icmp_reply_t *reply) static int extract_our_error(const icmp_reply_t *reply, uint16_t our_id, - icmp_offending_packet_t *out, uint16_t *seq_out) + struct in_addr dest, icmp_offending_packet_t *out, uint16_t *seq_out) { - if (0 > icmp_error_extract_offending(reply, out)) - return 0; - if (ICMP_TYPE_ECHO_REQUEST != out->icmp_type) - return 0; - if (our_id != out->rest.echo.id) + if ((0 > icmp_error_extract_offending(reply, out)) + || ICMP_TYPE_ECHO_REQUEST != out->icmp_type + || our_id != out->rest.echo.id || out->dst.s_addr != dest.s_addr) return 0; *seq_out = out->rest.echo.seq; return 1; @@ -80,7 +78,7 @@ handle_icmp_error(struct ping_state *state, const icmp_reply_t *reply) icmp_offending_packet_t offending; uint16_t seq; - if (0 == extract_our_error(reply, state->id, &offending, &seq)) + if (0 == extract_our_error(reply, state->id, state->dest, &offending, &seq)) return; if (0 > ping_tracker_record_recv(state->tracker, seq, &reply->timestamp)) return;