diff --git a/conn/dialer.c b/conn/dialer.c index 9f34914..9d3b8ef 100644 --- a/conn/dialer.c +++ b/conn/dialer.c @@ -30,30 +30,17 @@ struct Dialer* libp2p_conn_dialer_new(struct Libp2pPeer* peer, struct Peerstore* struct Dialer* dialer = (struct Dialer*)malloc(sizeof(struct Dialer)); if (dialer != NULL) { dialer->peerstore = peerstore; - dialer->peer_id = malloc(peer->id_size + 1); - memset(dialer->peer_id, 0, peer->id_size + 1); - if (dialer->peer_id != NULL) { - strncpy(dialer->peer_id, peer->id, peer->id_size); - // convert private key to rsa private key - /* - struct RsaPrivateKey* rsa_private_key = libp2p_crypto_rsa_rsa_private_key_new(); - if (!libp2p_crypto_encoding_x509_der_to_private_key(private_key->data, private_key->data_size, rsa_private_key)) { - libp2p_crypto_rsa_rsa_private_key_free(rsa_private_key); - libp2p_conn_dialer_free(dialer); - return NULL; + dialer->private_key = rsa_private_key; + dialer->transport_dialers = NULL; + dialer->fallback_dialer = libp2p_conn_tcp_transport_dialer_new(dialer->peer_id, rsa_private_key); + if (peer != NULL) { + dialer->peer_id = malloc(peer->id_size + 1); + memset(dialer->peer_id, 0, peer->id_size + 1); + if (dialer->peer_id != NULL) { + strncpy(dialer->peer_id, peer->id, peer->id_size); } - if (!libp2p_crypto_rsa_private_key_fill_public_key(rsa_private_key)) { - libp2p_crypto_rsa_rsa_private_key_free(rsa_private_key); - libp2p_conn_dialer_free(dialer); - return NULL; - } - */ - dialer->private_key = rsa_private_key; - //TODO: build transport dialers - dialer->transport_dialers = NULL; - dialer->fallback_dialer = libp2p_conn_tcp_transport_dialer_new(dialer->peer_id, rsa_private_key); - return dialer; } + return dialer; } libp2p_conn_dialer_free(dialer); return NULL; @@ -66,7 +53,8 @@ struct Dialer* libp2p_conn_dialer_new(struct Libp2pPeer* peer, struct Peerstore* */ void libp2p_conn_dialer_free(struct Dialer* in) { if (in != NULL) { - free(in->peer_id); + if (in->peer_id != NULL) + free(in->peer_id); libp2p_crypto_rsa_rsa_private_key_free(in->private_key); if (in->transport_dialers != NULL) { struct Libp2pLinkedList* current = in->transport_dialers; diff --git a/conn/transport_dialer.c b/conn/transport_dialer.c index 84a2b72..00ede25 100644 --- a/conn/transport_dialer.c +++ b/conn/transport_dialer.c @@ -6,12 +6,14 @@ struct TransportDialer* libp2p_conn_transport_dialer_new(char* peer_id, struct RsaPrivateKey* private_key) { struct TransportDialer* out = (struct TransportDialer*)malloc(sizeof(struct TransportDialer)); if (out != NULL) { - out->peer_id = malloc(strlen(peer_id) + 1); - strcpy(out->peer_id, peer_id); + out->peer_id = NULL; + out->private_key = NULL; + if (peer_id != NULL) { + out->peer_id = malloc(strlen(peer_id) + 1); + strcpy(out->peer_id, peer_id); + } if (private_key != NULL) { out->private_key = private_key; - } else { - out->private_key = NULL; } } return out; diff --git a/include/libp2p/net/connectionstream.h b/include/libp2p/net/connectionstream.h index d928e75..73e757f 100644 --- a/include/libp2p/net/connectionstream.h +++ b/include/libp2p/net/connectionstream.h @@ -3,6 +3,19 @@ #include "libp2p/net/stream.h" #include "libp2p/conn/session.h" +/*** + * This is a Stream wrapper around a basic tcp/ip connection + */ + +/*** + * Create a new stream based on a network connection, and attempt to connect + * @param fd the handle to the network connection + * @param ip the IP address of the connection + * @param port the port of the connection + * @returns a Stream + */ +struct Stream* libp2p_net_connection_new(int fd, char* ip, int port, struct SessionContext* session_context); + /*** * Create a new stream based on a network connection * @param fd the handle to the network connection @@ -10,7 +23,7 @@ * @param port the port of the connection * @returns a Stream */ -struct Stream* libp2p_net_connection_new(int fd, char* ip, int port, struct SessionContext* session_context); +struct Stream* libp2p_net_connection_established(int fd, char* ip, int port, struct SessionContext* session_context); /** * Attempt to upgrade the parent_stream to use the new stream by default diff --git a/include/libp2p/net/server.h b/include/libp2p/net/server.h new file mode 100644 index 0000000..0617d2c --- /dev/null +++ b/include/libp2p/net/server.h @@ -0,0 +1,19 @@ +/** + * Header for libp2p/net/server + */ + +/*** + * Start a server given the information + * NOTE: This spins off a thread. + * @param ip the ip address to attach to + * @param port the port to use + * @param protocol_handlers the protocol handlers + * @returns true(1) on success, false(0) otherwise + */ +int libp2p_net_server_start(const char* ip, int port, struct Libp2pVector* protocol_handlers); + +/*** + * Shut down the server started by libp2p_net_start_server + * @returns true(1) on success, false(0) otherwise + */ +int libp2p_net_server_stop(); diff --git a/include/libp2p/utils/thread_pool.h b/include/libp2p/utils/thread_pool.h new file mode 100644 index 0000000..04d5af2 --- /dev/null +++ b/include/libp2p/utils/thread_pool.h @@ -0,0 +1,187 @@ +/********************************** + * @author Johan Hanssen Seferidis + * License: MIT + * + **********************************/ + +#ifndef _THPOOL_ +#define _THPOOL_ + +#ifdef __cplusplus +extern "C" { +#endif + +/* =================================== API ======================================= */ + + +typedef struct thpool_* threadpool; + + +/** + * @brief Initialize threadpool + * + * Initializes a threadpool. This function will not return untill all + * threads have initialized successfully. + * + * @example + * + * .. + * threadpool thpool; //First we declare a threadpool + * thpool = thpool_init(4); //then we initialize it to 4 threads + * .. + * + * @param num_threads number of threads to be created in the threadpool + * @return threadpool created threadpool on success, + * NULL on error + */ +threadpool thpool_init(int num_threads); + + +/** + * @brief Add work to the job queue + * + * Takes an action and its argument and adds it to the threadpool's job queue. + * If you want to add to work a function with more than one arguments then + * a way to implement this is by passing a pointer to a structure. + * + * NOTICE: You have to cast both the function and argument to not get warnings. + * + * @example + * + * void print_num(int num){ + * printf("%d\n", num); + * } + * + * int main() { + * .. + * int a = 10; + * thpool_add_work(thpool, (void*)print_num, (void*)a); + * .. + * } + * + * @param threadpool threadpool to which the work will be added + * @param function_p pointer to function to add as work + * @param arg_p pointer to an argument + * @return 0 on successs, -1 otherwise. + */ +int thpool_add_work(threadpool, void (*function_p)(void*), void* arg_p); + + +/** + * @brief Wait for all queued jobs to finish + * + * Will wait for all jobs - both queued and currently running to finish. + * Once the queue is empty and all work has completed, the calling thread + * (probably the main program) will continue. + * + * Smart polling is used in wait. The polling is initially 0 - meaning that + * there is virtually no polling at all. If after 1 seconds the threads + * haven't finished, the polling interval starts growing exponentially + * untill it reaches max_secs seconds. Then it jumps down to a maximum polling + * interval assuming that heavy processing is being used in the threadpool. + * + * @example + * + * .. + * threadpool thpool = thpool_init(4); + * .. + * // Add a bunch of work + * .. + * thpool_wait(thpool); + * puts("All added work has finished"); + * .. + * + * @param threadpool the threadpool to wait for + * @return nothing + */ +void thpool_wait(threadpool); + + +/** + * @brief Pauses all threads immediately + * + * The threads will be paused no matter if they are idle or working. + * The threads return to their previous states once thpool_resume + * is called. + * + * While the thread is being paused, new work can be added. + * + * @example + * + * threadpool thpool = thpool_init(4); + * thpool_pause(thpool); + * .. + * // Add a bunch of work + * .. + * thpool_resume(thpool); // Let the threads start their magic + * + * @param threadpool the threadpool where the threads should be paused + * @return nothing + */ +void thpool_pause(threadpool); + + +/** + * @brief Unpauses all threads if they are paused + * + * @example + * .. + * thpool_pause(thpool); + * sleep(10); // Delay execution 10 seconds + * thpool_resume(thpool); + * .. + * + * @param threadpool the threadpool where the threads should be unpaused + * @return nothing + */ +void thpool_resume(threadpool); + + +/** + * @brief Destroy the threadpool + * + * This will wait for the currently active threads to finish and then 'kill' + * the whole threadpool to free up memory. + * + * @example + * int main() { + * threadpool thpool1 = thpool_init(2); + * threadpool thpool2 = thpool_init(2); + * .. + * thpool_destroy(thpool1); + * .. + * return 0; + * } + * + * @param threadpool the threadpool to destroy + * @return nothing + */ +void thpool_destroy(threadpool); + + +/** + * @brief Show currently working threads + * + * Working threads are the threads that are performing work (not idle). + * + * @example + * int main() { + * threadpool thpool1 = thpool_init(2); + * threadpool thpool2 = thpool_init(2); + * .. + * printf("Working threads: %d\n", thpool_num_threads_working(thpool1)); + * .. + * return 0; + * } + * + * @param threadpool the threadpool of interest + * @return integer number of threads working + */ +int thpool_num_threads_working(threadpool); + + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/include/libp2p/yamux/yamux.h b/include/libp2p/yamux/yamux.h index 6153e75..a550df8 100644 --- a/include/libp2p/yamux/yamux.h +++ b/include/libp2p/yamux/yamux.h @@ -24,6 +24,8 @@ struct YamuxContext { int am_server; int state; // the state of the connection struct Libp2pVector* protocol_handlers; + struct StreamMessage* buffered_message; + long buffered_message_pos; }; struct YamuxChannelContext { diff --git a/net/Makefile b/net/Makefile index 74eb944..59d703b 100644 --- a/net/Makefile +++ b/net/Makefile @@ -7,7 +7,7 @@ endif LFLAGS = DEPS = -OBJS = sctp.o socket.o tcp.o udp.o multistream.o protocol.o connectionstream.o stream.o +OBJS = sctp.o socket.o tcp.o udp.o multistream.o protocol.o connectionstream.o stream.o server.o %.o: %.c $(DEPS) $(CC) -c -o $@ $< $(CFLAGS) diff --git a/net/connectionstream.c b/net/connectionstream.c index a73f2d1..104398b 100644 --- a/net/connectionstream.c +++ b/net/connectionstream.c @@ -66,7 +66,54 @@ int libp2p_net_connection_peek(void* stream_context) { * @returns true(1) on success, false(0) otherwise */ int libp2p_net_connection_read(void* stream_context, struct StreamMessage** msg, int timeout_secs) { - return 0; + struct ConnectionContext* ctx = (struct ConnectionContext*) stream_context; + // read from the socket + uint8_t buffer[4096]; + uint8_t* result_buffer = NULL; + int current_size = 0; + while (1) { + int retVal = socket_read(ctx->socket_descriptor, (char*)&buffer[0], 4096, 0, timeout_secs); + if (retVal < 1) { // get out of the loop + if (retVal < 0) // error + return -1; + break; + } + // add what we got to the message + if (result_buffer == NULL) { + result_buffer = malloc(retVal); + if (result_buffer == NULL) + return 0; + current_size = retVal; + memcpy(result_buffer, buffer, retVal); + } else { + void* alloc = realloc(result_buffer, current_size + retVal); + if (alloc == NULL) { + free(result_buffer); + return 0; + } + memcpy(&result_buffer[current_size], buffer, retVal); + current_size += retVal; + } + // Everything ok, loop again (possibly) + if (retVal != 4096) + break; + } + + // now build the message + if (current_size > 0) { + *msg = libp2p_stream_message_new(); + struct StreamMessage* result = *msg; + if (result == NULL) { + free(result_buffer); + return 0; + } + result->data = result_buffer; + result->data_size = current_size; + result->error_number = 0; + libp2p_logger_debug("connectionstream", "libp2p_connectionstream_read: Received %d bytes", result->data_size); + } + + return current_size; } /** @@ -107,6 +154,14 @@ int libp2p_net_connection_write(void* stream_context, struct StreamMessage* msg) return socket_write(ctx->socket_descriptor, (char*)msg->data, msg->data_size, 0); } +int libp2p_net_handle_upgrade(struct Stream* old_stream, struct Stream* new_stream) { + struct ConnectionContext* ctx = (struct ConnectionContext*) old_stream->stream_context; + if (ctx->session_context != NULL) { + ctx->session_context->default_stream = new_stream; + } + return 1; +} + /*** * Create a new stream based on a network connection * @param fd the handle to the network connection @@ -114,7 +169,7 @@ int libp2p_net_connection_write(void* stream_context, struct StreamMessage* msg) * @param port the port of the connection * @returns a Stream */ -struct Stream* libp2p_net_connection_new(int fd, char* ip, int port, struct SessionContext* session_context) { +struct Stream* libp2p_net_connection_established(int fd, char* ip, int port, struct SessionContext* session_context) { struct Stream* out = (struct Stream*) malloc(sizeof(struct Stream)); if (out != NULL) { out->stream_type = STREAM_TYPE_RAW; @@ -123,6 +178,7 @@ struct Stream* libp2p_net_connection_new(int fd, char* ip, int port, struct Sess out->read = libp2p_net_connection_read; out->read_raw = libp2p_net_connection_read_raw; out->write = libp2p_net_connection_write; + out->handle_upgrade = libp2p_net_handle_upgrade; // Multiaddresss char str[strlen(ip) + 25]; memset(str, 0, strlen(ip) + 16); @@ -138,11 +194,26 @@ struct Stream* libp2p_net_connection_new(int fd, char* ip, int port, struct Sess out->stream_context = ctx; ctx->socket_descriptor = fd; ctx->session_context = session_context; - if (!socket_connect4_with_timeout(ctx->socket_descriptor, hostname_to_ip(ip), port, 10) == 0) { - // unable to connect - libp2p_stream_free(out); - out = NULL; - } + } + } + return out; +} + +/*** + * Create a new stream based on a network connection, and attempt to connect + * @param fd the handle to the network connection + * @param ip the IP address of the connection + * @param port the port of the connection + * @returns a Stream + */ +struct Stream* libp2p_net_connection_new(int fd, char* ip, int port, struct SessionContext* session_context) { + struct Stream* out = libp2p_net_connection_established(fd, ip, port, session_context); + if (out != NULL) { + struct ConnectionContext* ctx = (struct ConnectionContext*) out->stream_context; + if (!socket_connect4_with_timeout(ctx->socket_descriptor, hostname_to_ip(ip), port, 10) == 0) { + // unable to connect + libp2p_stream_free(out); + out = NULL; } } return out; diff --git a/net/protocol.c b/net/protocol.c index 252d077..84bb887 100644 --- a/net/protocol.c +++ b/net/protocol.c @@ -15,10 +15,12 @@ * @returns true(1) if there was a match, false(0) otherwise */ const struct Libp2pProtocolHandler* protocol_compare(struct StreamMessage* msg, struct Libp2pVector* protocol_handlers) { - for(int i = 0; i < protocol_handlers->total; i++) { - const struct Libp2pProtocolHandler* handler = (const struct Libp2pProtocolHandler*) libp2p_utils_vector_get(protocol_handlers, i); - if (handler->CanHandle(msg)) { - return handler; + if (protocol_handlers != NULL) { + for(int i = 0; i < protocol_handlers->total; i++) { + const struct Libp2pProtocolHandler* handler = (const struct Libp2pProtocolHandler*) libp2p_utils_vector_get(protocol_handlers, i); + if (handler->CanHandle(msg)) { + return handler; + } } } return NULL; @@ -50,16 +52,8 @@ int libp2p_protocol_marshal(struct StreamMessage* msg, struct Stream* stream, st const struct Libp2pProtocolHandler* handler = protocol_compare(msg, handlers); if (handler == NULL) { - // turn msg->data to a null terminated string for the error message - char str[msg->data_size + 1]; - memcpy(str, msg->data, msg->data_size); - str[msg->data_size] = 0; - for(int i = 0; i < msg->data_size; i++) { - if (str[i] == '\n') { - str[i] = 0; - break; - } - } + // set the msg->error code + msg->error_number = 100; return -1; } diff --git a/net/server.c b/net/server.c new file mode 100644 index 0000000..ccbda72 --- /dev/null +++ b/net/server.c @@ -0,0 +1,187 @@ +/** + * A simple tcp server that uses thread pools and protocol handlers + */ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "libp2p/conn/session.h" +#include "libp2p/net/connectionstream.h" +#include "libp2p/net/multistream.h" +#include "libp2p/net/p2pnet.h" +#include "libp2p/net/protocol.h" +#include "libp2p/nodeio/nodeio.h" +#include "libp2p/os/utils.h" +#include "libp2p/record/message.h" +#include "libp2p/routing/dht_protocol.h" +#include "libp2p/secio/secio.h" +#include "libp2p/utils/logger.h" +#include "libp2p/utils/thread_pool.h" + +struct server_connection_params { + uint32_t ip_address_binary; + const char* ip_address_text; + uint16_t port; + struct Libp2pVector* protocol_handlers; +}; + +struct client_connection_params { + int file_descriptor; + int count; + uint16_t port; + char* ip; + struct Libp2pVector* protocol_handlers; +}; + +// this is the thread id NOTE: there should only be 1 server per instance, as this is a global +pthread_t server_pthread; + +#define BUF_SIZE 4096 + +// this should be set to 5 for normal operation, perhaps higher for debugging purposes +#define DEFAULT_NETWORK_TIMEOUT 5 + +static int server_shutting_down = 0; + +/** + * We've received a new connection. Find out what they want. + * + * @param ptr a pointer to a null_connection_params struct + */ +void libp2p_net_connection (void *ptr) { + struct client_connection_params *connection_param = (struct client_connection_params*) ptr; + int retVal = 0; + + libp2p_logger_info("null", "Connection %d, count %d\n", connection_param->file_descriptor, connection_param->count); + + //TODO: build a stream from the given information + struct Stream* clientStream = libp2p_net_connection_established(connection_param->file_descriptor, connection_param->ip, connection_param->port, NULL); + + // try to read from the network + struct StreamMessage *results = NULL; + // handle the call + for(;;) { + // Read from the network + if (!clientStream->read(clientStream->stream_context, &results, DEFAULT_NETWORK_TIMEOUT)) { + // problem reading + break; + } + if (results != NULL) { + retVal = libp2p_protocol_marshal(results, clientStream, connection_param->protocol_handlers); + libp2p_stream_message_free(results); + results = NULL; + } + if (retVal < 0 || server_shutting_down) { + // exit the loop on error + break; + } + } // end of loop + + connection_param->count--; // update counter. + if (connection_param->ip != NULL) + free(connection_param->ip); + free (connection_param); + return; +} + +/*** + * Called by the daemon to listen for connections + * @param ptr a pointer to an IpfsNodeListenParams struct + * @returns nothing useful. + */ +void* libp2p_server_listen (void *ptr) +{ + server_shutting_down = 0; + int socketfd, s, count = 0; + threadpool thpool = thpool_init(25); + struct server_connection_params *connection_param = (struct server_connection_params*)ptr; + + if ((socketfd = socket_listen(socket_tcp4(), &(connection_param->ip_address_binary), &(connection_param->port))) <= 0) { + libp2p_logger_error("null", "Failed to init null router. Address: %d, Port: %d\n", connection_param->ip_address_text, connection_param->port); + return (void*) 2; + } + + struct client_connection_params* clientConnection = NULL; + + // the main loop, listening for new connections + for (;;) { + int numDescriptors = socket_read_select4(socketfd, 2); + if (server_shutting_down) { + break; + } + if (numDescriptors > 0) { + s = socket_accept4(socketfd, &(connection_param->ip_address_binary), &(connection_param->port)); + if (count >= 50) { // limit reached. + close (s); + continue; + } + + count++; + clientConnection = malloc (sizeof (struct client_connection_params)); + if (clientConnection) { + clientConnection->file_descriptor = s; + clientConnection->count = count; + clientConnection->port = connection_param->port; + clientConnection->ip = malloc(INET_ADDRSTRLEN); + clientConnection->protocol_handlers = connection_param->protocol_handlers; + if (clientConnection->ip == NULL) { + // we are out of memory + free(clientConnection); + continue; + } + if (inet_ntop(AF_INET, &(connection_param->ip_address_binary), clientConnection->ip, INET_ADDRSTRLEN) == NULL) { + free(clientConnection->ip); + clientConnection->ip = NULL; + clientConnection->port = 0; + } + // Create pthread for clientConnection. + thpool_add_work(thpool, libp2p_net_connection, clientConnection); + } + } else { + // timeout... + } + } + + thpool_destroy(thpool); + + free(connection_param); + + close(socketfd); + + return (void*) 2; +} + +/*** + * Start a server given the information + * NOTE: This spins off a thread. + * @param ip the ip address to attach to + * @param port the port to use + * @param protocol_handlers the protocol handlers + * @returns true(1) on success, false(0) otherwise + */ +int libp2p_net_server_start(const char* ip, int port, struct Libp2pVector* protocol_handlers) { + struct server_connection_params* params = (struct server_connection_params*) malloc(sizeof(struct server_connection_params)); + params->ip_address_text = ip; + inet_pton(AF_INET, ip, ¶ms->ip_address_binary); + params->port = port; + params->protocol_handlers = protocol_handlers; + // start on a separate thread + pthread_create(&server_pthread, NULL, libp2p_server_listen, params); + return 1; +} + +/*** + * Shut down the server started by libp2p_net_start_server + * @returns true(1) on success, false(0) otherwise + */ +int libp2p_net_server_stop() { + server_shutting_down = 1; + pthread_join(server_pthread, NULL); + return 1; +} diff --git a/secio/secio.c b/secio/secio.c index 1c98b08..cffef89 100644 --- a/secio/secio.c +++ b/secio/secio.c @@ -1308,6 +1308,14 @@ int libp2p_secio_peek(void* stream_context) { return ctx->stream->parent_stream->peek(ctx->stream->parent_stream->stream_context); } +/*** + * Read a certain amount of bytes from the network + * @param stream_context the secio context + * @param buffer where to put the bytes read + * @param buffer_size the size of the incoming buffer + * @param timeout_secs the network timeout + * @returns the number of bytes read. + */ int libp2p_secio_read_raw(void* stream_context, uint8_t* buffer, int buffer_size, int timeout_secs) { if (stream_context == NULL) { return -1; @@ -1320,6 +1328,7 @@ int libp2p_secio_read_raw(void* stream_context, uint8_t* buffer, int buffer_size } ctx->buffered_message_pos = 0; } + // max_to_read is the lesser of bytes read or buffer_size int max_to_read = (buffer_size > ctx->buffered_message->data_size ? ctx->buffered_message->data_size : buffer_size); memcpy(buffer, &ctx->buffered_message->data[ctx->buffered_message_pos], max_to_read); ctx->buffered_message_pos += max_to_read; diff --git a/test/test_net.h b/test/test_net.h new file mode 100644 index 0000000..3b37e4d --- /dev/null +++ b/test/test_net.h @@ -0,0 +1,10 @@ +#include +#include "libp2p/net/server.h" + +int test_net_server_startup_shutdown() { + + libp2p_net_server_start("127.0.0.1", 1234, NULL); + sleep(5); + libp2p_net_server_stop(); + return 1; +} diff --git a/test/test_yamux.h b/test/test_yamux.h index aa6da5f..9e0ea5f 100644 --- a/test/test_yamux.h +++ b/test/test_yamux.h @@ -5,6 +5,7 @@ #include "libp2p/utils/logger.h" #include "libp2p/net/stream.h" #include "libp2p/net/multistream.h" +#include "libp2p/net/server.h" /*** * Helpers @@ -50,6 +51,24 @@ int mock_identify_read_protocol(void* context, struct StreamMessage** msg, int n return 1; } +/*** + * Sends back the identify protocol (in a yamux wrapper) to fake negotiation + */ +int mock_multistream_read_protocol(void* context, struct StreamMessage** msg, int network_timeout) { + struct StreamMessage message; + const char* id = "/multistream/1.0.0\n"; + message.data_size = strlen(id); + message.data = (uint8_t*)id; + + *msg = libp2p_yamux_prepare_to_send(&message); + // adjust the frame + struct yamux_frame* frame = (struct yamux_frame*)(*msg)->data; + frame->streamid = 1; + frame->flags = yamux_frame_syn; + encode_frame(frame); + return 1; +} + int mock_counter = 0; /*** @@ -137,6 +156,43 @@ int test_yamux_identify() { return retVal; } +/*** + * Attempt to add a protocol to the Yamux protocol + */ +/* +int test_yamux_multistream() { + int retVal = 0; + // setup + // mock + struct Stream* mock_stream = mock_stream_new(); + mock_stream->read = mock_yamux_read_protocol; + // protocol handlers + struct Libp2pVector* protocol_handlers = libp2p_utils_vector_new(1); + struct Libp2pProtocolHandler* handler = libp2p_identify_build_protocol_handler(protocol_handlers); + libp2p_utils_vector_add(protocol_handlers, handler); + // yamux + struct Stream* yamux_stream = libp2p_yamux_stream_new(mock_stream, 0, protocol_handlers); + if (yamux_stream == NULL) + goto exit; + // Now add in another protocol + mock_stream->read = mock_multistream_read_protocol; + if (!libp2p_yamux_stream_add(yamux_stream->stream_context, libp2p_multistream_stream_new(yamux_stream))) { + goto exit; + } + // tear down + retVal = 1; + exit: + if (yamux_stream != NULL) + yamux_stream->close(yamux_stream); + libp2p_protocol_handlers_shutdown(protocol_handlers); + if (mock_message != NULL) { + libp2p_stream_message_free(mock_message); + mock_message = NULL; + } + return retVal; +} +*/ + int test_yamux_incoming_protocol_request() { int retVal = 0; @@ -270,3 +326,42 @@ int test_yamux_identity_frame() { return retVal; } + +int test_yamux_client_server_connect() { + int retVal = 0; + struct Libp2pVector* protocol_handlers = NULL; + struct StreamMessage* resultMessage = NULL; + + //libp2p_logger_add_class("connectionstream"); + + // setup + // build the protocol handler that can handle yamux + protocol_handlers = libp2p_utils_vector_new(1); + struct Libp2pProtocolHandler* handler = libp2p_yamux_build_protocol_handler(protocol_handlers); + libp2p_utils_vector_add(protocol_handlers, handler); + // set up server + libp2p_net_server_start("127.0.0.1", 1234, protocol_handlers); + sleep(1); + // set up client (easiest to use transport dialers) + struct Dialer* dialer = libp2p_conn_dialer_new(NULL, NULL, NULL); + struct MultiAddress* server_ma = multiaddress_new_from_string("/ip4/127.0.0.1/tcp/1234"); + struct Stream* stream = libp2p_conn_dialer_get_connection(dialer, server_ma); + if (stream == NULL) { + fprintf(stderr, "Unable to get stream.\n"); + goto exit; + } + // have client attempt to connect to server and negotiate yamux + struct Stream* yamux_stream = libp2p_yamux_stream_new(stream, 0, protocol_handlers); + if (yamux_stream == NULL) { + fprintf(stderr, "Was supposed to get yamux protocol id, but instead received nothing.\n"); + goto exit; + } + retVal = 1; + exit: + libp2p_net_server_stop(); + if (protocol_handlers != NULL) { + + } + return retVal; + +} diff --git a/test/testit.c b/test/testit.c index bcf995f..292795a 100644 --- a/test/testit.c +++ b/test/testit.c @@ -14,6 +14,7 @@ #include "test_record.h" #include "test_peer.h" #include "test_yamux.h" +#include "test_net.h" #include "libp2p/utils/logger.h" struct test { @@ -117,6 +118,8 @@ int build_test_collection() { add_test("test_yamux_stream_new", test_yamux_stream_new, 1); add_test("test_yamux_identify", test_yamux_identify, 1); add_test("test_yamux_incoming_protocol_request", test_yamux_incoming_protocol_request, 1); + add_test("test_net_server_startup_shutdown", test_net_server_startup_shutdown, 1); + add_test("test_yamux_client_server_connect", test_yamux_client_server_connect, 1); return 1; }; diff --git a/utils/Makefile b/utils/Makefile index 6adedf8..e8bd434 100644 --- a/utils/Makefile +++ b/utils/Makefile @@ -7,7 +7,7 @@ endif LFLAGS = DEPS = -OBJS = string_list.o vector.o linked_list.o logger.o urlencode.o +OBJS = string_list.o vector.o linked_list.o logger.o urlencode.o thread_pool.o %.o: %.c $(DEPS) $(CC) -c -o $@ $< $(CFLAGS) diff --git a/utils/thread_pool.c b/utils/thread_pool.c new file mode 100644 index 0000000..d32e7e9 --- /dev/null +++ b/utils/thread_pool.c @@ -0,0 +1,555 @@ +/* ******************************** + * Author: Johan Hanssen Seferidis + * License: MIT + * Description: Library providing a threading pool where you can add + * work. For usage, check the thpool.h file or README.md + * + *//** @file thpool.h *//* + * + ********************************/ + +#define _POSIX_C_SOURCE 200809L +#include +#include +#include +#include +#include +#include +#include +#if defined(__linux__) +#include +#endif + +#include "libp2p/utils/thread_pool.h" + +#ifdef THPOOL_DEBUG +#define THPOOL_DEBUG 1 +#else +#define THPOOL_DEBUG 0 +#endif + +#if !defined(DISABLE_PRINT) || defined(THPOOL_DEBUG) +#define err(str) fprintf(stderr, str) +#else +#define err(str) +#endif + +static volatile int threads_keepalive; +static volatile int threads_on_hold; + + + +/* ========================== STRUCTURES ============================ */ + + +/* Binary semaphore */ +typedef struct bsem { + pthread_mutex_t mutex; + pthread_cond_t cond; + int v; +} bsem; + + +/* Job */ +typedef struct job{ + struct job* prev; /* pointer to previous job */ + void (*function)(void* arg); /* function pointer */ + void* arg; /* function's argument */ +} job; + + +/* Job queue */ +typedef struct jobqueue{ + pthread_mutex_t rwmutex; /* used for queue r/w access */ + job *front; /* pointer to front of queue */ + job *rear; /* pointer to rear of queue */ + bsem *has_jobs; /* flag as binary semaphore */ + int len; /* number of jobs in queue */ +} jobqueue; + + +/* Thread */ +typedef struct thread{ + int id; /* friendly id */ + pthread_t pthread; /* pointer to actual thread */ + struct thpool_* thpool_p; /* access to thpool */ +} thread; + + +/* Threadpool */ +typedef struct thpool_{ + thread** threads; /* pointer to threads */ + volatile int num_threads_alive; /* threads currently alive */ + volatile int num_threads_working; /* threads currently working */ + pthread_mutex_t thcount_lock; /* used for thread count etc */ + pthread_cond_t threads_all_idle; /* signal to thpool_wait */ + jobqueue jobqueue; /* job queue */ +} thpool_; + + + + + +/* ========================== PROTOTYPES ============================ */ + + +static int thread_init(thpool_* thpool_p, struct thread** thread_p, int id); +static void* thread_do(struct thread* thread_p); +static void thread_hold(int sig_id); +static void thread_destroy(struct thread* thread_p); + +static int jobqueue_init(jobqueue* jobqueue_p); +static void jobqueue_clear(jobqueue* jobqueue_p); +static void jobqueue_push(jobqueue* jobqueue_p, struct job* newjob_p); +static struct job* jobqueue_pull(jobqueue* jobqueue_p); +static void jobqueue_destroy(jobqueue* jobqueue_p); + +static void bsem_init(struct bsem *bsem_p, int value); +static void bsem_reset(struct bsem *bsem_p); +static void bsem_post(struct bsem *bsem_p); +static void bsem_post_all(struct bsem *bsem_p); +static void bsem_wait(struct bsem *bsem_p); + + + + + +/* ========================== THREADPOOL ============================ */ + + +/* Initialise thread pool */ +struct thpool_* thpool_init(int num_threads){ + + threads_on_hold = 0; + threads_keepalive = 1; + + if (num_threads < 0){ + num_threads = 0; + } + + /* Make new thread pool */ + thpool_* thpool_p; + thpool_p = (struct thpool_*)malloc(sizeof(struct thpool_)); + if (thpool_p == NULL){ + err("thpool_init(): Could not allocate memory for thread pool\n"); + return NULL; + } + thpool_p->num_threads_alive = 0; + thpool_p->num_threads_working = 0; + + /* Initialise the job queue */ + if (jobqueue_init(&thpool_p->jobqueue) == -1){ + err("thpool_init(): Could not allocate memory for job queue\n"); + free(thpool_p); + return NULL; + } + + /* Make threads in pool */ + thpool_p->threads = (struct thread**)malloc(num_threads * sizeof(struct thread *)); + if (thpool_p->threads == NULL){ + err("thpool_init(): Could not allocate memory for threads\n"); + jobqueue_destroy(&thpool_p->jobqueue); + free(thpool_p); + return NULL; + } + + pthread_mutex_init(&(thpool_p->thcount_lock), NULL); + pthread_cond_init(&thpool_p->threads_all_idle, NULL); + + /* Thread init */ + int n; + for (n=0; nthreads[n], n); +#if THPOOL_DEBUG + printf("THPOOL_DEBUG: Created thread %d in pool \n", n); +#endif + } + + /* Wait for threads to initialize */ + while (thpool_p->num_threads_alive != num_threads) {} + + return thpool_p; +} + + +/* Add work to the thread pool */ +int thpool_add_work(thpool_* thpool_p, void (*function_p)(void*), void* arg_p){ + job* newjob; + + newjob=(struct job*)malloc(sizeof(struct job)); + if (newjob==NULL){ + err("thpool_add_work(): Could not allocate memory for new job\n"); + return -1; + } + + /* add function and argument */ + newjob->function=function_p; + newjob->arg=arg_p; + + /* add job to queue */ + jobqueue_push(&thpool_p->jobqueue, newjob); + + return 0; +} + + +/* Wait until all jobs have finished */ +void thpool_wait(thpool_* thpool_p){ + pthread_mutex_lock(&thpool_p->thcount_lock); + while (thpool_p->jobqueue.len || thpool_p->num_threads_working) { + pthread_cond_wait(&thpool_p->threads_all_idle, &thpool_p->thcount_lock); + } + pthread_mutex_unlock(&thpool_p->thcount_lock); +} + + +/* Destroy the threadpool */ +void thpool_destroy(thpool_* thpool_p){ + /* No need to destory if it's NULL */ + if (thpool_p == NULL) return ; + + volatile int threads_total = thpool_p->num_threads_alive; + + /* End each thread 's infinite loop */ + threads_keepalive = 0; + + /* Give one second to kill idle threads */ + double TIMEOUT = 1.0; + time_t start, end; + double tpassed = 0.0; + time (&start); + while (tpassed < TIMEOUT && thpool_p->num_threads_alive){ + bsem_post_all(thpool_p->jobqueue.has_jobs); + time (&end); + tpassed = difftime(end,start); + } + + /* Poll remaining threads */ + while (thpool_p->num_threads_alive){ + bsem_post_all(thpool_p->jobqueue.has_jobs); + sleep(1); + } + + /* Job queue cleanup */ + jobqueue_destroy(&thpool_p->jobqueue); + /* Deallocs */ + int n; + for (n=0; n < threads_total; n++){ + thread_destroy(thpool_p->threads[n]); + } + free(thpool_p->threads); + free(thpool_p); +} + + +/* Pause all threads in threadpool */ +void thpool_pause(thpool_* thpool_p) { + int n; + for (n=0; n < thpool_p->num_threads_alive; n++){ + pthread_kill(thpool_p->threads[n]->pthread, SIGUSR1); + } +} + + +/* Resume all threads in threadpool */ +void thpool_resume(thpool_* thpool_p) { + // resuming a single threadpool hasn't been + // implemented yet, meanwhile this supresses + // the warnings + (void)thpool_p; + + threads_on_hold = 0; +} + + +int thpool_num_threads_working(thpool_* thpool_p){ + return thpool_p->num_threads_working; +} + + + + + +/* ============================ THREAD ============================== */ + + +/* Initialize a thread in the thread pool + * + * @param thread address to the pointer of the thread to be created + * @param id id to be given to the thread + * @return 0 on success, -1 otherwise. + */ +static int thread_init (thpool_* thpool_p, struct thread** thread_p, int id){ + + *thread_p = (struct thread*)malloc(sizeof(struct thread)); + if (thread_p == NULL){ + err("thread_init(): Could not allocate memory for thread\n"); + return -1; + } + + (*thread_p)->thpool_p = thpool_p; + (*thread_p)->id = id; + + pthread_create(&(*thread_p)->pthread, NULL, (void *)thread_do, (*thread_p)); + pthread_detach((*thread_p)->pthread); + return 0; +} + + +/* Sets the calling thread on hold */ +static void thread_hold(int sig_id) { + (void)sig_id; + threads_on_hold = 1; + while (threads_on_hold){ + sleep(1); + } +} + + +#if defined(__APPLE__) && defined(__MACH__) + int pthread_setname_np(const char* name); +#endif + +/* What each thread is doing +* +* In principle this is an endless loop. The only time this loop gets interuppted is once +* thpool_destroy() is invoked or the program exits. +* +* @param thread thread that will run this function +* @return nothing +*/ +static void* thread_do(struct thread* thread_p){ + + /* Set thread name for profiling and debuging */ + char thread_name[128] = {0}; + sprintf(thread_name, "thread-pool-%d", thread_p->id); + +#if defined(__linux__) + /* Use prctl instead to prevent using _GNU_SOURCE flag and implicit declaration */ + prctl(PR_SET_NAME, thread_name); +#elif defined(__APPLE__) && defined(__MACH__) + pthread_setname_np(thread_name); +#else + err("thread_do(): pthread_setname_np is not supported on this system"); +#endif + + /* Assure all threads have been created before starting serving */ + thpool_* thpool_p = thread_p->thpool_p; + + /* Register signal handler */ + struct sigaction act; + sigemptyset(&act.sa_mask); + act.sa_flags = 0; + act.sa_handler = thread_hold; + if (sigaction(SIGUSR1, &act, NULL) == -1) { + err("thread_do(): cannot handle SIGUSR1"); + } + + /* Mark thread as alive (initialized) */ + pthread_mutex_lock(&thpool_p->thcount_lock); + thpool_p->num_threads_alive += 1; + pthread_mutex_unlock(&thpool_p->thcount_lock); + + while(threads_keepalive){ + + bsem_wait(thpool_p->jobqueue.has_jobs); + + if (threads_keepalive){ + + pthread_mutex_lock(&thpool_p->thcount_lock); + thpool_p->num_threads_working++; + pthread_mutex_unlock(&thpool_p->thcount_lock); + + /* Read job from queue and execute it */ + void (*func_buff)(void*); + void* arg_buff; + job* job_p = jobqueue_pull(&thpool_p->jobqueue); + if (job_p) { + func_buff = job_p->function; + arg_buff = job_p->arg; + func_buff(arg_buff); + free(job_p); + } + + pthread_mutex_lock(&thpool_p->thcount_lock); + thpool_p->num_threads_working--; + if (!thpool_p->num_threads_working) { + pthread_cond_signal(&thpool_p->threads_all_idle); + } + pthread_mutex_unlock(&thpool_p->thcount_lock); + + } + } + pthread_mutex_lock(&thpool_p->thcount_lock); + thpool_p->num_threads_alive --; + pthread_mutex_unlock(&thpool_p->thcount_lock); + + return NULL; +} + + +/* Frees a thread */ +static void thread_destroy (thread* thread_p){ + free(thread_p); +} + + + + + +/* ============================ JOB QUEUE =========================== */ + + +/* Initialize queue */ +static int jobqueue_init(jobqueue* jobqueue_p){ + jobqueue_p->len = 0; + jobqueue_p->front = NULL; + jobqueue_p->rear = NULL; + + jobqueue_p->has_jobs = (struct bsem*)malloc(sizeof(struct bsem)); + if (jobqueue_p->has_jobs == NULL){ + return -1; + } + + pthread_mutex_init(&(jobqueue_p->rwmutex), NULL); + bsem_init(jobqueue_p->has_jobs, 0); + + return 0; +} + + +/* Clear the queue */ +static void jobqueue_clear(jobqueue* jobqueue_p){ + + while(jobqueue_p->len){ + free(jobqueue_pull(jobqueue_p)); + } + + jobqueue_p->front = NULL; + jobqueue_p->rear = NULL; + bsem_reset(jobqueue_p->has_jobs); + jobqueue_p->len = 0; + +} + + +/* Add (allocated) job to queue + */ +static void jobqueue_push(jobqueue* jobqueue_p, struct job* newjob){ + + pthread_mutex_lock(&jobqueue_p->rwmutex); + newjob->prev = NULL; + + switch(jobqueue_p->len){ + + case 0: /* if no jobs in queue */ + jobqueue_p->front = newjob; + jobqueue_p->rear = newjob; + break; + + default: /* if jobs in queue */ + jobqueue_p->rear->prev = newjob; + jobqueue_p->rear = newjob; + + } + jobqueue_p->len++; + + bsem_post(jobqueue_p->has_jobs); + pthread_mutex_unlock(&jobqueue_p->rwmutex); +} + + +/* Get first job from queue(removes it from queue) +<<<<<<< HEAD + * + * Notice: Caller MUST hold a mutex +======= +>>>>>>> da2c0fe45e43ce0937f272c8cd2704bdc0afb490 + */ +static struct job* jobqueue_pull(jobqueue* jobqueue_p){ + + pthread_mutex_lock(&jobqueue_p->rwmutex); + job* job_p = jobqueue_p->front; + + switch(jobqueue_p->len){ + + case 0: /* if no jobs in queue */ + break; + + case 1: /* if one job in queue */ + jobqueue_p->front = NULL; + jobqueue_p->rear = NULL; + jobqueue_p->len = 0; + break; + + default: /* if >1 jobs in queue */ + jobqueue_p->front = job_p->prev; + jobqueue_p->len--; + /* more than one job in queue -> post it */ + bsem_post(jobqueue_p->has_jobs); + + } + + pthread_mutex_unlock(&jobqueue_p->rwmutex); + return job_p; +} + + +/* Free all queue resources back to the system */ +static void jobqueue_destroy(jobqueue* jobqueue_p){ + jobqueue_clear(jobqueue_p); + free(jobqueue_p->has_jobs); +} + + + + + +/* ======================== SYNCHRONISATION ========================= */ + + +/* Init semaphore to 1 or 0 */ +static void bsem_init(bsem *bsem_p, int value) { + if (value < 0 || value > 1) { + err("bsem_init(): Binary semaphore can take only values 1 or 0"); + exit(1); + } + pthread_mutex_init(&(bsem_p->mutex), NULL); + pthread_cond_init(&(bsem_p->cond), NULL); + bsem_p->v = value; +} + + +/* Reset semaphore to 0 */ +static void bsem_reset(bsem *bsem_p) { + bsem_init(bsem_p, 0); +} + + +/* Post to at least one thread */ +static void bsem_post(bsem *bsem_p) { + pthread_mutex_lock(&bsem_p->mutex); + bsem_p->v = 1; + pthread_cond_signal(&bsem_p->cond); + pthread_mutex_unlock(&bsem_p->mutex); +} + + +/* Post to all threads */ +static void bsem_post_all(bsem *bsem_p) { + pthread_mutex_lock(&bsem_p->mutex); + bsem_p->v = 1; + pthread_cond_broadcast(&bsem_p->cond); + pthread_mutex_unlock(&bsem_p->mutex); +} + + +/* Wait on semaphore until semaphore has value 0 */ +static void bsem_wait(bsem* bsem_p) { + pthread_mutex_lock(&bsem_p->mutex); + while (bsem_p->v != 1) { + pthread_cond_wait(&bsem_p->cond, &bsem_p->mutex); + } + bsem_p->v = 0; + pthread_mutex_unlock(&bsem_p->mutex); +} diff --git a/yamux/session.c b/yamux/session.c index 322dca1..9f5e8e0 100644 --- a/yamux/session.c +++ b/yamux/session.c @@ -280,13 +280,18 @@ int yamux_decode(void* context, const uint8_t* incoming, size_t incoming_size, s memcpy(msg->data, &incoming[sizeof(struct yamux_frame)], msg->data_size); } - struct Stream* yamuxChannelStream = yamux_channel_new(yamuxContext, f.streamid, msg); - struct YamuxChannelContext* channelContext = (struct YamuxChannelContext*)yamuxChannelStream->stream_context; + // if we didn't initiate it, add this new channel (odd stream id is from client, even is from server) + if ( (f.streamid % 2 == 0 && yamuxContext->am_server) || (f.streamid % 2 == 1 && yamuxContext->am_server) ) { + struct Stream* yamuxChannelStream = yamux_channel_new(yamuxContext, f.streamid, msg); + if (yamuxChannelStream == NULL) + return -EPROTO; + struct YamuxChannelContext* channelContext = (struct YamuxChannelContext*)yamuxChannelStream->stream_context; - if (yamux_session->new_stream_fn) - yamux_session->new_stream_fn(yamuxContext, yamuxContext->stream, msg); + if (yamux_session->new_stream_fn) + yamux_session->new_stream_fn(yamuxContext, yamuxContext->stream, msg); - channelContext->state = yamux_stream_syn_recv; + channelContext->state = yamux_stream_syn_recv; + } *return_message = msg; } else diff --git a/yamux/yamux.c b/yamux/yamux.c index 3cd6bb2..9e8871e 100644 --- a/yamux/yamux.c +++ b/yamux/yamux.c @@ -272,6 +272,24 @@ int libp2p_yamux_write(void* stream_context, struct StreamMessage* message) { return retVal; } +/*** + * Given a context, get the YamuxContext + * @param stream_context a YamuxChannelContext or a YamuxContext + * @returns the YamuxContext, or NULL on error + */ +struct YamuxContext* libp2p_yamux_get_context(void* stream_context) { + char proto = ((uint8_t*)stream_context)[0]; + struct YamuxChannelContext* channel = NULL; + struct YamuxContext* ctx = NULL; + if (proto == YAMUX_CHANNEL_CONTEXT) { + channel = (struct YamuxChannelContext*)stream_context; + ctx = channel->yamux_context; + } else if (proto == YAMUX_CONTEXT) { + ctx = (struct YamuxContext*)stream_context; + } + return ctx; +} + /*** * Check to see if there is anything waiting on the network. * @param stream_context the YamuxContext @@ -281,7 +299,7 @@ int libp2p_yamux_peek(void* stream_context) { if (stream_context == NULL) return -1; - struct YamuxContext* ctx = (struct YamuxContext*)stream_context; + struct YamuxContext* ctx = libp2p_yamux_get_context(stream_context); struct Stream* parent_stream = ctx->stream->parent_stream; if (parent_stream == NULL) return -1; @@ -289,9 +307,41 @@ int libp2p_yamux_peek(void* stream_context) { return parent_stream->peek(parent_stream->stream_context); } +/*** + * Read from the network, and place it in the buffer + * NOTE: This may put something in the internal read buffer (i.e. buffer_size is too small) + * @param stream_context the yamux context + * @param buffer the buffer + * @param buffer_size the size of the incoming buffer (max number of bytes to read) + * @param timeout_secs timeout + * @returns number of bytes read. + */ int libp2p_yamux_read_raw(void* stream_context, uint8_t* buffer, int buffer_size, int timeout_secs) { - //TODO: Implement - return -1; + if (stream_context == NULL) { + return -1; + } + struct YamuxContext* ctx = libp2p_yamux_get_context(stream_context); + if (ctx->buffered_message_pos == -1) { + // we need to get info from the network + if (!ctx->stream->read(ctx->stream->stream_context, &ctx->buffered_message, timeout_secs)) { + return -1; + } + ctx->buffered_message_pos = 0; + } + // max_to_read is the lesser of bytes read or buffer_size + int max_to_read = (buffer_size > ctx->buffered_message->data_size ? ctx->buffered_message->data_size : buffer_size); + memcpy(buffer, &ctx->buffered_message->data[ctx->buffered_message_pos], max_to_read); + ctx->buffered_message_pos += max_to_read; + if (ctx->buffered_message_pos == ctx->buffered_message->data_size) { + // we read everything + libp2p_stream_message_free(ctx->buffered_message); + ctx->buffered_message = NULL; + ctx->buffered_message_pos = -1; + } else { + // we didn't read everything. + ctx->buffered_message_pos = max_to_read; + } + return max_to_read; } /** @@ -308,11 +358,14 @@ struct YamuxContext* libp2p_yamux_context_new(struct Stream* stream) { ctx->session = yamux_session_new(NULL, stream, yamux_session_server, NULL); ctx->am_server = 0; ctx->state = 0; + ctx->buffered_message = NULL; + ctx->buffered_message_pos = -1; + ctx->protocol_handlers = NULL; } return ctx; } -int libp2p_yamux_negotiate(struct YamuxContext* ctx) { +int libp2p_yamux_negotiate(struct YamuxContext* ctx, int am_server) { const char* protocolID = "/yamux/1.0.0\n"; struct StreamMessage outgoing; struct StreamMessage* results = NULL; @@ -320,23 +373,25 @@ int libp2p_yamux_negotiate(struct YamuxContext* ctx) { int haveTheirs = 0; int peek_result = 0; - // see if they're trying to send something first - peek_result = libp2p_yamux_peek(ctx); - if (peek_result > 0) { - libp2p_logger_debug("yamux", "There is %d bytes waiting for us. Perhaps it is the yamux header we're expecting.\n", peek_result); - // get the protocol - ctx->stream->parent_stream->read(ctx->stream->parent_stream, &results, yamux_default_timeout); - if (results == NULL || results->data_size == 0) { - libp2p_logger_error("yamux", "We thought we had a yamux header, but we got nothing.\n"); - goto exit; + // see if they're trying to send something first (only if we're the client) + if (!am_server) { + peek_result = libp2p_yamux_peek(ctx); + if (peek_result > 0) { + libp2p_logger_debug("yamux", "There is %d bytes waiting for us. Perhaps it is the yamux header we're expecting.\n", peek_result); + // get the protocol + ctx->stream->parent_stream->read(ctx->stream->parent_stream, &results, yamux_default_timeout); + if (results == NULL || results->data_size == 0) { + libp2p_logger_error("yamux", "We thought we had a yamux header, but we got nothing.\n"); + goto exit; + } + if (strncmp((char*)results->data, protocolID, strlen(protocolID)) != 0) { + libp2p_logger_error("yamux", "We thought we had a yamux header, but we received %d bytes that contained %s.\n", (int)results->data_size, results->data); + goto exit; + } + libp2p_stream_message_free(results); + results = NULL; + haveTheirs = 1; } - if (strncmp((char*)results->data, protocolID, strlen(protocolID)) != 0) { - libp2p_logger_error("yamux", "We thought we had a yamux header, but we received %d bytes that contained %s.\n", (int)results->data_size, results->data); - goto exit; - } - libp2p_stream_message_free(results); - results = NULL; - haveTheirs = 1; } // send the protocol id @@ -348,7 +403,7 @@ int libp2p_yamux_negotiate(struct YamuxContext* ctx) { } // wait for them to send the protocol id back - if (!haveTheirs) { + if (!am_server && !haveTheirs) { // expect the same back ctx->stream->parent_stream->read(ctx->stream->parent_stream->stream_context, &results, yamux_default_timeout); if (results == NULL || results->data_size == 0) { @@ -430,7 +485,7 @@ struct Stream* libp2p_yamux_stream_new(struct Stream* parent_stream, int am_serv ctx->am_server = am_server; ctx->protocol_handlers = protocol_handlers; // attempt to negotiate yamux protocol - if (!libp2p_yamux_negotiate(ctx)) { + if (!libp2p_yamux_negotiate(ctx, am_server)) { libp2p_yamux_stream_free(out); return NULL; } @@ -463,6 +518,10 @@ int libp2p_yamux_channel_close(void* context) { void libp2p_yamux_context_free(struct YamuxContext* ctx) { if (ctx == NULL) return; + if (ctx->buffered_message != NULL) { + libp2p_stream_message_free(ctx->buffered_message); + ctx->buffered_message = NULL; + } // free all the channels if (ctx->channels) { for(int i = 0; i < ctx->channels->total; i++) {