diff --git a/src/main.h b/src/main.h index 9a1dc17..12ad584 100644 --- a/src/main.h +++ b/src/main.h @@ -45,9 +45,9 @@ static int response_body(http_parser *, const char *, size_t); static uint64_t time_us(); -static char *extract_url_part(char *, struct http_parser_url *, enum http_parser_url_fields); +static int parse_args(struct config *, char **, struct http_parser_url *, char **, int, char **); +static char *copy_url_part(char *, struct http_parser_url *, enum http_parser_url_fields); -static int parse_args(struct config *, char **, char **, int, char **); static void print_stats_header(); static void print_stats(char *, stats *, char *(*)(long double)); static void print_stats_latency(stats *); diff --git a/src/script.c b/src/script.c index 475370f..f070b20 100644 --- a/src/script.c +++ b/src/script.c @@ -23,6 +23,8 @@ static int script_wrk_lookup(lua_State *); static int script_wrk_connect(lua_State *); static void set_fields(lua_State *, int index, const table_field *); +static void set_string(lua_State *, int, char *, char *, size_t); +static char *get_url_part(char *, struct http_parser_url *, enum http_parser_url_fields, size_t *); static const struct luaL_reg addrlib[] = { { "__tostring", script_addr_tostring }, @@ -43,7 +45,7 @@ static const struct luaL_reg threadlib[] = { { NULL, NULL } }; -lua_State *script_create(char *file, char *scheme, char *host, char *port, char *path) { +lua_State *script_create(char *file, char *url, char **headers) { lua_State *L = luaL_newstate(); luaL_openlibs(L); (void) luaL_dostring(L, "wrk = require \"wrk\""); @@ -55,19 +57,39 @@ lua_State *script_create(char *file, char *scheme, char *host, char *port, char luaL_newmetatable(L, "wrk.thread"); luaL_register(L, NULL, threadlib); + struct http_parser_url parts = {}; + script_parse_url(url, &parts); + char *path = "/"; + size_t len; + + if (parts.field_set & (1 << UF_PATH)) { + path = &url[parts.field_data[UF_PATH].off]; + } + const table_field fields[] = { - { "scheme", LUA_TSTRING, scheme }, - { "host", LUA_TSTRING, host }, - { "port", LUA_TSTRING, port }, - { "path", LUA_TSTRING, path }, { "lookup", LUA_TFUNCTION, script_wrk_lookup }, { "connect", LUA_TFUNCTION, script_wrk_connect }, + { "path", LUA_TSTRING, path }, { NULL, 0, NULL }, }; lua_getglobal(L, "wrk"); + + set_string(L, 4, "scheme", get_url_part(url, &parts, UF_SCHEMA, &len), len); + set_string(L, 4, "host", get_url_part(url, &parts, UF_HOST, &len), len); + set_string(L, 4, "port", get_url_part(url, &parts, UF_PORT, &len), len); set_fields(L, 4, fields); - lua_pop(L, 4); + + lua_getfield(L, 4, "headers"); + for (char **h = headers; *h; h++) { + char *p = strchr(*h, ':'); + if (p && p[1] == ' ') { + lua_pushlstring(L, *h, p - *h); + lua_pushstring(L, p + 2); + lua_settable(L, 5); + } + } + lua_pop(L, 5); if (file && luaL_dofile(L, file)) { const char *cause = lua_tostring(L, -1); @@ -98,43 +120,26 @@ void script_push_thread(lua_State *L, thread *t) { lua_setmetatable(L, -2); } -void script_setup(lua_State *L, thread *t) { +void script_init(lua_State *L, thread *t, int argc, char **argv) { lua_getglobal(t->L, "wrk"); + script_push_thread(t->L, t); lua_setfield(t->L, -2, "thread"); - lua_pop(t->L, 1); lua_getglobal(L, "wrk"); lua_getfield(L, -1, "setup"); script_push_thread(L, t); lua_call(L, 1, 0); lua_pop(L, 1); -} -void script_headers(lua_State *L, char **headers) { - lua_getglobal(L, "wrk"); - lua_getfield(L, 1, "headers"); - for (char **h = headers; *h; h++) { - char *p = strchr(*h, ':'); - if (p && p[1] == ' ') { - lua_pushlstring(L, *h, p - *h); - lua_pushstring(L, p + 2); - lua_settable(L, 2); - } - } - lua_pop(L, 2); -} - -void script_init(lua_State *L, int argc, char **argv) { - lua_getglobal(L, "wrk"); - lua_getfield(L, -1, "init"); - lua_newtable(L); + lua_getfield(t->L, -1, "init"); + lua_newtable(t->L); for (int i = 0; i < argc; i++) { - lua_pushstring(L, argv[i]); - lua_rawseti(L, -2, i); + lua_pushstring(t->L, argv[i]); + lua_rawseti(t->L, -2, i); } - lua_call(L, 1, 0); - lua_pop(L, 1); + lua_call(t->L, 1, 0); + lua_pop(t->L, 1); } void script_request(lua_State *L, char **buf, size_t *len) { @@ -488,6 +493,31 @@ void script_copy_value(lua_State *src, lua_State *dst, int index) { } } +int script_parse_url(char *url, struct http_parser_url *parts) { + if (!http_parser_parse_url(url, strlen(url), 0, parts)) { + if (!(parts->field_set & (1 << UF_SCHEMA))) return 0; + if (!(parts->field_set & (1 << UF_HOST))) return 0; + return 1; + } + return 0; +} + +static char *get_url_part(char *url, struct http_parser_url *parts, enum http_parser_url_fields field, size_t *len) { + char *value = NULL; + if (parts->field_set & (1 << field)) { + value = &url[parts->field_data[field].off]; + *len = parts->field_data[field].len; + } + return value; +} + +static void set_string(lua_State *L, int index, char *field, char *value, size_t len) { + if (value != NULL) { + lua_pushlstring(L, value, len); + lua_setfield(L, index, field); + } +} + static void set_fields(lua_State *L, int index, const table_field *fields) { for (int i = 0; fields[i].name; i++) { table_field f = fields[i]; diff --git a/src/script.h b/src/script.h index 54f7f17..f6861d4 100644 --- a/src/script.h +++ b/src/script.h @@ -9,14 +9,13 @@ #include "stats.h" #include "wrk.h" -lua_State *script_create(char *, char *, char *, char *, char *); +lua_State *script_create(char *, char *, char **); bool script_resolve(lua_State *, char *, char *); void script_setup(lua_State *, thread *); void script_done(lua_State *, stats *, stats *); -void script_headers(lua_State *, char **); -void script_init(lua_State *, int, char **); +void script_init(lua_State *, thread *, int, char **); void script_request(lua_State *, char **, size_t *); void script_response(lua_State *, int, buffer *, buffer *); size_t script_verify_request(lua_State *L); @@ -28,6 +27,7 @@ void script_summary(lua_State *, uint64_t, uint64_t, uint64_t); void script_errors(lua_State *, errors *); void script_copy_value(lua_State *, lua_State *, int); +int script_parse_url(char *, struct http_parser_url *); void buffer_append(buffer *, const char *, size_t); void buffer_reset(buffer *); diff --git a/src/wrk.c b/src/wrk.c index a212301..9dfa3c7 100644 --- a/src/wrk.c +++ b/src/wrk.c @@ -57,30 +57,18 @@ static void usage() { } int main(int argc, char **argv) { - struct http_parser_url parser_url; - char *url, **headers; + char *url, **headers = zmalloc(argc * sizeof(char *)); + struct http_parser_url parts = {}; - headers = zmalloc((argc / 2) * sizeof(char *)); - - if (parse_args(&cfg, &url, headers, argc, argv)) { + if (parse_args(&cfg, &url, &parts, headers, argc, argv)) { usage(); exit(1); } - if (http_parser_parse_url(url, strlen(url), 0, &parser_url)) { - fprintf(stderr, "invalid URL: %s\n", url); - exit(1); - } - - char *schema = extract_url_part(url, &parser_url, UF_SCHEMA); - char *host = extract_url_part(url, &parser_url, UF_HOST); - char *port = extract_url_part(url, &parser_url, UF_PORT); + char *schema = copy_url_part(url, &parts, UF_SCHEMA); + char *host = copy_url_part(url, &parts, UF_HOST); + char *port = copy_url_part(url, &parts, UF_PORT); char *service = port ? port : schema; - char *path = "/"; - - if (parser_url.field_set & (1 << UF_PATH)) { - path = &url[parser_url.field_data[UF_PATH].off]; - } if (!strncmp("https", schema, 5)) { if ((cfg.ctx = ssl_init()) == NULL) { @@ -100,10 +88,9 @@ int main(int argc, char **argv) { statistics.latency = stats_alloc(cfg.timeout * 1000); statistics.requests = stats_alloc(MAX_THREAD_RATE_S); + thread *threads = zcalloc(cfg.threads * sizeof(thread)); - thread *threads = zcalloc(cfg.threads * sizeof(thread)); - - lua_State *L = script_create(cfg.script, schema, host, port, path); + lua_State *L = script_create(cfg.script, url, headers); if (!script_resolve(L, host, service)) { char *msg = strerror(errno); fprintf(stderr, "unable to connect to %s:%s %s\n", host, service, msg); @@ -116,10 +103,8 @@ int main(int argc, char **argv) { t->connections = cfg.connections / cfg.threads; t->latency = stats_alloc(cfg.timeout * 1000); - t->L = script_create(cfg.script, schema, host, port, path); - script_headers(t->L, headers); - script_setup(L, t); - script_init(t->L, argc - optind, &argv[optind]); + t->L = script_create(cfg.script, url, headers); + script_init(L, t, argc - optind, &argv[optind]); if (i == 0) { cfg.pipeline = script_verify_request(t->L); @@ -468,12 +453,12 @@ static uint64_t time_us() { return (t.tv_sec * 1000000) + t.tv_usec; } -static char *extract_url_part(char *url, struct http_parser_url *parser_url, enum http_parser_url_fields field) { +static char *copy_url_part(char *url, struct http_parser_url *parts, enum http_parser_url_fields field) { char *part = NULL; - if (parser_url->field_set & (1 << field)) { - uint16_t off = parser_url->field_data[field].off; - uint16_t len = parser_url->field_data[field].len; + if (parts->field_set & (1 << field)) { + uint16_t off = parts->field_data[field].off; + uint16_t len = parts->field_data[field].len; part = zcalloc(len + 1 * sizeof(char)); memcpy(part, &url[off], len); } @@ -494,7 +479,7 @@ static struct option longopts[] = { { NULL, 0, NULL, 0 } }; -static int parse_args(struct config *cfg, char **url, char **headers, int argc, char **argv) { +static int parse_args(struct config *cfg, char **url, struct http_parser_url *parts, char **headers, int argc, char **argv) { char **header = headers; int c; @@ -542,6 +527,11 @@ static int parse_args(struct config *cfg, char **url, char **headers, int argc, if (optind == argc || !cfg->threads || !cfg->duration) return -1; + if (!script_parse_url(argv[optind], parts)) { + fprintf(stderr, "invalid URL: %s\n", argv[optind]); + return -1; + } + if (!cfg->connections || cfg->connections < cfg->threads) { fprintf(stderr, "number of connections must be >= threads\n"); return -1;