agent: implement connection loop with retry

This commit is contained in:
2019-01-21 23:26:39 +02:00
parent 598fd7aa74
commit 986af66adb
2 changed files with 186 additions and 155 deletions

View File

@@ -83,10 +83,11 @@ static void signal_handler(int sig)
static void agent_shutdown(void) static void agent_shutdown(void)
{ {
printf("Shutting down agent...\n"); printf("Shutting down agent...\n");
SSL_shutdown(ssl); if (ssl) SSL_shutdown(ssl);
SSL_free(ssl); if (ssl) SSL_free(ssl);
SSL_CTX_free(ctx); /* release context */ if (ctx) SSL_CTX_free(ctx); /* release context */
close(server); /* close socket */ close(server); /* close socket */
_exit(EXIT_SUCCESS);
} }
static void set_env(void) static void set_env(void)
@@ -116,7 +117,7 @@ static void set_env(void)
int main(int count, char *strings[]) int main(int count, char *strings[])
{ {
int bytes, i; int bytes, i, err;
char *hostname, *portnum; char *hostname, *portnum;
if (count != 6) { if (count != 6) {
@@ -127,158 +128,189 @@ int main(int count, char *strings[])
portnum=strings[2]; portnum=strings[2];
if ((ctx = init_ctx()) == NULL) if ((ctx = init_ctx()) == NULL)
_exit(EXIT_FAILURE); _exit(EXIT_FAILURE);
server = connect_to_rmps(hostname, atoi(portnum));
if (!server) {
fprintf(stderr, "Failed to connect to RMPS: %s:%d\n", hostname, atoi(portnum));
_exit(EXIT_FAILURE);
}
load_certs(ctx, strings[3], strings[4], strings[5]);
ssl = SSL_new(ctx); /* create new SSL connection state */
SSL_set_fd(ssl, server); /* attach the socket descriptor */
if (SSL_connect(ssl) == FAIL) { /* perform the connection */ while (1) {
ERR_print_errors_fp(stderr); server = connect_to_rmps(hostname, atoi(portnum));
close(server); if (!server) {
SSL_CTX_free(ctx); fprintf(stderr, "Failed to connect to RMPS on %s:%d - %s\n", hostname, atoi(portnum), strerror(errno));
_exit(EXIT_FAILURE); printf("Retrying...\n");
} sleep(5);
printf("Connected with %s encryption\n", SSL_get_cipher(ssl)); continue;
show_certs(ssl); }
set_env(); load_certs(ctx, strings[3], strings[4], strings[5]);
atexit(agent_shutdown); ssl = SSL_new(ctx); /* create new SSL connection state */
if (!(args = calloc(1, sizeof(*args) * MAX_AGENT_JOBS))) { SSL_set_fd(ssl, server); /* attach the socket descriptor */
fprintf( stderr,
"Failed to calloc() %d bytes for job_args! Exiting...\n",
(int)sizeof(struct job_args) * MAX_AGENT_JOBS );
SSL_shutdown(ssl);
SSL_free(ssl);
close(server);
SSL_CTX_free(ctx);
_exit(EXIT_FAILURE);
}
if (!(job_thread = calloc(1, sizeof(*job_thread) * MAX_AGENT_JOBS))) {
fprintf( stderr,
"Failed to calloc() %d bytes for job_threads! Exiting...\n",
(int)sizeof(pthread_t) * MAX_AGENT_JOBS );
SSL_shutdown(ssl);
SSL_free(ssl);
close(server);
SSL_CTX_free(ctx);
free(args);
_exit(EXIT_FAILURE);
}
for (i = 0; i < MAX_AGENT_JOBS; i++) {
args[i].slot = FREE;
args[i].ssl = ssl;
}
do { if (SSL_connect(ssl) == FAIL) { /* perform the connection */
struct msg_t buf; ERR_print_errors_fp(stderr);
memset(&buf, 0, sizeof(struct msg_t)); close(server);
bytes = SSL_read(ssl, &buf, sizeof(struct msg_t)); SSL_CTX_free(ctx);
if (bytes > 0) { _exit(EXIT_FAILURE);
short index; }
if (bytes != sizeof(struct msg_t)) { printf("Connected with %s encryption\n", SSL_get_cipher(ssl));
fprintf( stderr, show_certs(ssl);
"Received non-standard data from server!\n" ); set_env();
continue; atexit(agent_shutdown);
} if (!(args = calloc(1, sizeof(*args) * MAX_AGENT_JOBS))) {
if (buf.chunk.id == 0) { fprintf( stderr,
if ((index = get_job_slot()) == FAIL) { "Failed to calloc() %d bytes for job_args! Exiting...\n",
buf.chunk.id = -1; /* ID -1 means reject (full) */ (int)sizeof(struct job_args) * MAX_AGENT_JOBS );
sprintf((char*)buf.chunk.data, "The agent's queue is full!"); SSL_shutdown(ssl);
SSL_write(ssl, &buf, sizeof(struct msg_t)); SSL_free(ssl);
continue; close(server);
} SSL_CTX_free(ctx);
args[index].slot = FULL; _exit(EXIT_FAILURE);
memcpy(&args[index].buf, &buf, sizeof(struct msg_t)); }
switch (args[index].buf.meta.type) { if (!(job_thread = calloc(1, sizeof(*job_thread) * MAX_AGENT_JOBS))) {
case UNIX: fprintf( stderr,
pthread_create( &job_thread[index], "Failed to calloc() %d bytes for job_threads! Exiting...\n",
NULL, (int)sizeof(pthread_t) * MAX_AGENT_JOBS );
exec_unix, SSL_shutdown(ssl);
&args[index] ); SSL_free(ssl);
continue; close(server);
case INSTALL_PKG: SSL_CTX_free(ctx);
pthread_create( &job_thread[index], free(args);
NULL, _exit(EXIT_FAILURE);
install_pkg, }
&args[index] ); for (i = 0; i < MAX_AGENT_JOBS; i++) {
continue; args[i].slot = FREE;
case QUERY_PKG: args[i].ssl = ssl;
pthread_create( &job_thread[index],
NULL,
query_pkg,
&args[index] );
continue;
case DELETE_PKG:
pthread_create( &job_thread[index],
NULL,
delete_pkg,
&args[index] );
continue;
case LIST_PKGS:
pthread_create( &job_thread[index],
NULL,
list_pkgs,
&args[index] );
continue;
case UPDATE_PKG:
pthread_create( &job_thread[index],
NULL,
update_pkg,
&args[index] );
continue;
case UPDATE_PKGS:
pthread_create( &job_thread[index],
NULL,
update_pkgs,
&args[index] );
continue;
case GET_OS:
pthread_create( &job_thread[index],
NULL,
get_os,
&args[index] );
continue;
case GET_KERNEL:
pthread_create( &job_thread[index],
NULL,
get_kernel,
&args[index] );
continue;
case GET_UPTIME:
pthread_create( &job_thread[index],
NULL,
get_uptime,
&args[index] );
continue;
case GET_MEMORY:
pthread_create( &job_thread[index],
NULL,
get_memory,
&args[index] );
continue;
default:
buf.chunk.id = -1;
sprintf( (char*)buf.chunk.data,
"Unsupported job type with ID: %d",
buf.meta.type );
SSL_write(ssl, &buf, sizeof(struct msg_t));
continue;
}
} else {
index = find_job(buf.meta.id);
if (index == FAIL) {
sprintf( (char*)buf.chunk.data,
"Data was sent for an invalid job ID" );
SSL_write(ssl, &buf, sizeof(struct msg_t));
} else
memcpy(&args[index].buf, &buf, sizeof(struct msg_t));
}
} }
SSL_shutdown(ssl); do {
SSL_free(ssl); /* release connection state */ struct msg_t buf;
} while (bytes); memset(&buf, 0, sizeof(struct msg_t));
bytes = SSL_read(ssl, &buf, sizeof(struct msg_t));
if (bytes > 0) {
short index;
if (bytes != sizeof(struct msg_t)) {
fprintf( stderr,
"Received non-standard data from server!\n" );
//conntinue;
return 1;
}
if (buf.chunk.id == 0) {
if ((index = get_job_slot()) == FAIL) {
buf.chunk.id = -1; /* ID -1 means reject (full) */
sprintf((char*)buf.chunk.data, "The agent's queue is full!");
SSL_write(ssl, &buf, sizeof(struct msg_t));
continue;
}
args[index].slot = FULL;
memcpy(&args[index].buf, &buf, sizeof(struct msg_t));
switch (args[index].buf.meta.type) {
case UNIX:
pthread_create( &job_thread[index],
NULL,
exec_unix,
&args[index] );
continue;
case INSTALL_PKG:
pthread_create( &job_thread[index],
NULL,
install_pkg,
&args[index] );
continue;
case QUERY_PKG:
pthread_create( &job_thread[index],
NULL,
query_pkg,
&args[index] );
continue;
case DELETE_PKG:
pthread_create( &job_thread[index],
NULL,
delete_pkg,
&args[index] );
continue;
case LIST_PKGS:
pthread_create( &job_thread[index],
NULL,
list_pkgs,
&args[index] );
continue;
case UPDATE_PKG:
pthread_create( &job_thread[index],
NULL,
update_pkg,
&args[index] );
continue;
case UPDATE_PKGS:
pthread_create( &job_thread[index],
NULL,
update_pkgs,
&args[index] );
continue;
case GET_OS:
pthread_create( &job_thread[index],
NULL,
get_os,
&args[index] );
continue;
case GET_KERNEL:
pthread_create( &job_thread[index],
NULL,
get_kernel,
&args[index] );
continue;
case GET_UPTIME:
pthread_create( &job_thread[index],
NULL,
get_uptime,
&args[index] );
continue;
case GET_MEMORY:
pthread_create( &job_thread[index],
NULL,
get_memory,
&args[index] );
continue;
default:
buf.chunk.id = -1;
sprintf( (char*)buf.chunk.data,
"Unsupported job type with ID: %d",
buf.meta.type );
SSL_write(ssl, &buf, sizeof(struct msg_t));
continue;
}
} else {
index = find_job(buf.meta.id);
if (index == FAIL) {
sprintf( (char*)buf.chunk.data,
"Data was sent for an invalid job ID" );
SSL_write(ssl, &buf, sizeof(struct msg_t));
} else
memcpy(&args[index].buf, &buf, sizeof(struct msg_t));
}
}
SSL_shutdown(ssl);
SSL_free(ssl); /* release connection state */
} while (bytes);
if (SSL_get_shutdown(ssl)) {
printf("RMPS server has shutdown, trying to reconnect...\n");
sleep(5);
} else {
err = SSL_get_error(ssl, bytes);
switch (err) {
case SSL_ERROR_WANT_WRITE:
printf("want_write\n");
return 0;
case SSL_ERROR_WANT_READ:
printf("want_read\n");
return 0;
case SSL_ERROR_ZERO_RETURN:
printf("zero_return\n");
return -1;
case SSL_ERROR_SYSCALL:
printf("syscall\n");
return -1;
case SSL_ERROR_SSL:
printf("ssl\n");
return -1;
default:
return -1;
}
}
}
} }

View File

@@ -49,7 +49,6 @@ int connect_to_rmps(const char *hostname, int port)
addr.sin_addr.s_addr = *(long*)(host->h_addr); addr.sin_addr.s_addr = *(long*)(host->h_addr);
if (connect(sd, (struct sockaddr*)&addr, sizeof(addr)) == FAIL) { if (connect(sd, (struct sockaddr*)&addr, sizeof(addr)) == FAIL) {
close(sd); close(sd);
perror(hostname);
return 0; return 0;
} }
return sd; return sd;