Skip to content

Commit

Permalink
add more buffer length checks and a fuzzer target
Browse files Browse the repository at this point in the history
  • Loading branch information
mwarning committed Nov 17, 2024
1 parent 50982bb commit ec5b3ef
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 18 deletions.
3 changes: 3 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@

all: main.c
gcc -Wall main.c -o main

fuzzer: main.c
clang -g -DFUZZER -O1 -fsanitize=fuzzer,address main.c -o fuzzer
85 changes: 67 additions & 18 deletions main.c
Original file line number Diff line number Diff line change
Expand Up @@ -328,10 +328,10 @@ void put32bits(uint8_t **buffer, uint32_t value)
*/

// 3foo3bar3com0 => foo.bar.com (No full validation is done!)
char *decode_domain_name(const uint8_t **buf, size_t len)
char *decode_domain_name(const uint8_t **buf, size_t buflen)
{
char domain[256];
for (int i = 1; i < MIN(256, len); i += 1) {
for (int i = 1; i < MIN(256, buflen); i += 1) {
uint8_t c = (*buf)[i];
if (c == 0) {
domain[i - 1] = 0;
Expand Down Expand Up @@ -380,12 +380,15 @@ void encode_domain_name(uint8_t **buffer, const char *domain)
*buffer += i;
}


void decode_header(struct Message *msg, const uint8_t **buffer)
bool decode_header(struct Message *msg, const uint8_t **buf, size_t buflen)
{
msg->id = get16bits(buffer);
if (buflen < 12) {
return false;
}

uint32_t fields = get16bits(buffer);
msg->id = get16bits(buf);

uint32_t fields = get16bits(buf);
msg->qr = (fields & QR_MASK) >> 15;
msg->opcode = (fields & OPCODE_MASK) >> 11;
msg->aa = (fields & AA_MASK) >> 10;
Expand All @@ -394,10 +397,12 @@ void decode_header(struct Message *msg, const uint8_t **buffer)
msg->ra = (fields & RA_MASK) >> 7;
msg->rcode = (fields & RCODE_MASK) >> 0;

msg->qdCount = get16bits(buffer);
msg->anCount = get16bits(buffer);
msg->nsCount = get16bits(buffer);
msg->arCount = get16bits(buffer);
msg->qdCount = get16bits(buf);
msg->anCount = get16bits(buf);
msg->nsCount = get16bits(buf);
msg->arCount = get16bits(buf);

return true;
}

void encode_header(struct Message *msg, uint8_t **buffer)
Expand All @@ -416,14 +421,20 @@ void encode_header(struct Message *msg, uint8_t **buffer)
put16bits(buffer, msg->arCount);
}

bool decode_msg(struct Message *msg, const uint8_t *buffer, size_t size)
bool decode_msg(struct Message *msg, const uint8_t *buf, size_t buflen)
{
int i;
const uint8_t *cur = buf;
const uint8_t *end = buf + buflen;

if (size < 12)
if (buflen > BUFFER_SIZE) {
printf("Too much input data!\n");
return false;
}

decode_header(msg, &buffer);
if (!decode_header(msg, &cur, end - cur)) {
printf("Failed to decode header!\n");
return false;
}

if (msg->anCount != 0 || msg->nsCount != 0) {
printf("Only questions expected!\n");
Expand All @@ -432,18 +443,27 @@ bool decode_msg(struct Message *msg, const uint8_t *buffer, size_t size)

// parse questions
uint32_t qcount = msg->qdCount;
for (i = 0; i < qcount; i += 1) {
for (int i = 0; i < qcount; i += 1) {
struct Question *q = calloc(1, sizeof(struct Question));

q->qName = decode_domain_name(&buffer, size);
q->qType = get16bits(&buffer);
q->qClass = get16bits(&buffer);
q->qName = decode_domain_name(&cur, end - cur);

if (q->qName == NULL) {
printf("Failed to decode domain name!\n");
free(q);
return false;
}

if ((end - cur) < 4) {
printf("Data too small!\n");
free(q->qName);
free(q);
return false;
}

q->qType = get16bits(&cur);
q->qClass = get16bits(&cur);

// prepend question to questions list
q->next = msg->questions;
msg->questions = q;
Expand Down Expand Up @@ -638,6 +658,34 @@ void free_questions(struct Question *qq)
}
}

#ifdef FUZZER
int LLVMFuzzerTestOneInput(const uint8_t *data, size_t size)
{
struct Message msg = {0};

// Assume `data` is a DNS query packet and `size` is its length.
// You might need to adapt this if your decode_msg function expects
// a different format or additional parameters.
decode_msg(&msg, data, size);

/* Print query */
print_message(&msg);

/* Resolve query and put the answers into the query message */
resolve_query(&msg);

/* Print response */
print_message(&msg);

// Free any resources allocated by decode_msg to prevent memory leaks.
free_questions(msg.questions);
free_resource_records(msg.answers);
free_resource_records(msg.authorities);
free_resource_records(msg.additionals);

return 0; // Non-zero return values are reserved for future use.
}
#else
int main()
{
// buffer for input/output binary packet
Expand Down Expand Up @@ -707,3 +755,4 @@ int main()
sendto(sock, buffer, buflen, 0, (struct sockaddr*) &client_addr, addr_len);
}
}
#endif

0 comments on commit ec5b3ef

Please sign in to comment.