From 6f0aa32ede58ce82b5dfe591b711b00f9b5f1054 Mon Sep 17 00:00:00 2001 From: Will Date: Sat, 7 Feb 2015 14:08:30 +0900 Subject: [PATCH] move address resolution into lua --- src/main.h | 1 - src/script.c | 132 +++++++++++++++++++++++++++++++++++++++++++++++++-- src/script.h | 6 +++ src/wrk.c | 35 +++----------- src/wrk.h | 2 + src/wrk.lua | 52 ++++++++++++-------- 6 files changed, 173 insertions(+), 55 deletions(-) diff --git a/src/main.h b/src/main.h index a1e1cdb..a101ae2 100644 --- a/src/main.h +++ b/src/main.h @@ -6,7 +6,6 @@ #include #include #include -#include #include #include #include diff --git a/src/script.c b/src/script.c index 7561e3c..0d2799d 100644 --- a/src/script.c +++ b/src/script.c @@ -4,6 +4,7 @@ #include #include "script.h" #include "http_parser.h" +#include "zmalloc.h" typedef struct { char *name; @@ -11,10 +12,20 @@ typedef struct { void *value; } table_field; +static int script_addr_tostring(lua_State *); +static int script_addr_gc(lua_State *); static int script_stats_len(lua_State *); static int script_stats_get(lua_State *); +static int script_wrk_lookup(lua_State *); +static int script_wrk_connect(lua_State *); static void set_fields(lua_State *, int index, const table_field *); +static const struct luaL_reg addrlib[] = { + { "__tostring", script_addr_tostring }, + { "__gc" , script_addr_gc }, + { NULL, NULL } +}; + static const struct luaL_reg statslib[] = { { "__index", script_stats_get }, { "__len", script_stats_len }, @@ -26,16 +37,22 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) { luaL_openlibs(L); luaL_dostring(L, "wrk = require \"wrk\""); + luaL_newmetatable(L, "wrk.addr"); + luaL_register(L, NULL, addrlib); + lua_pop(L, 1); + luaL_newmetatable(L, "wrk.stats"); luaL_register(L, NULL, statslib); lua_pop(L, 1); const table_field fields[] = { - { "scheme", LUA_TSTRING, scheme }, - { "host", LUA_TSTRING, host }, - { "port", LUA_TSTRING, port }, - { "path", LUA_TSTRING, path }, - { NULL, 0, NULL }, + { "scheme", LUA_TSTRING, scheme }, + { "host", LUA_TSTRING, host }, + { "port", LUA_TSTRING, port }, + { "path", LUA_TSTRING, path }, + { "lookup", LUA_TFUNCTION, script_wrk_lookup }, + { "connect", LUA_TFUNCTION, script_wrk_connect }, + { NULL, 0, NULL }, }; lua_getglobal(L, "wrk"); @@ -45,6 +62,36 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) { return L; } +void script_prepare_setup(lua_State *L, char *script) { + if (script && luaL_dofile(L, script)) { + const char *cause = lua_tostring(L, -1); + fprintf(stderr, "%s: %s\n", script, cause); + } +} + +bool script_resolve(lua_State *L, char *host, char *service) { + lua_getglobal(L, "wrk"); + + lua_getfield(L, -1, "resolve"); + lua_pushstring(L, host); + lua_pushstring(L, service); + lua_call(L, 2, 0); + + lua_getfield(L, -1, "addrs"); + size_t count = lua_objlen(L, -1); + lua_pop(L, 2); + return count > 0; +} + +struct addrinfo *script_peek_addr(lua_State *L) { + lua_getglobal(L, "wrk"); + lua_getfield(L, -1, "addrs"); + lua_rawgeti(L, -1, 1); + struct addrinfo *addr = lua_touserdata(L, -1); + lua_pop(L, 3); + return addr; +} + void script_headers(lua_State *L, char **headers) { lua_getglobal(L, "wrk"); lua_getfield(L, 1, "headers"); @@ -225,6 +272,34 @@ size_t script_verify_request(lua_State *L) { return count; } +static struct addrinfo *checkaddr(lua_State *L) { + struct addrinfo *addr = luaL_checkudata(L, -1, "wrk.addr"); + luaL_argcheck(L, addr != NULL, 1, "`addr' expected"); + return addr; +} + +static int script_addr_tostring(lua_State *L) { + struct addrinfo *addr = checkaddr(L); + char host[NI_MAXHOST]; + char service[NI_MAXSERV]; + + int flags = NI_NUMERICHOST | NI_NUMERICSERV; + int rc = getnameinfo(addr->ai_addr, addr->ai_addrlen, host, NI_MAXHOST, service, NI_MAXSERV, flags); + if (rc != 0) { + const char *msg = gai_strerror(rc); + return luaL_error(L, "addr tostring failed %s", msg); + } + + lua_pushfstring(L, "%s:%s", host, service); + return 1; +} + +static int script_addr_gc(lua_State *L) { + struct addrinfo *addr = checkaddr(L); + zfree(addr->ai_addr); + return 0; +} + static stats *checkstats(lua_State *L) { stats **s = luaL_checkudata(L, 1, "wrk.stats"); luaL_argcheck(L, s != NULL, 1, "`stats' expected"); @@ -262,10 +337,57 @@ static int script_stats_len(lua_State *L) { return 1; } +static int script_wrk_lookup(lua_State *L) { + struct addrinfo *addrs; + struct addrinfo hints = { + .ai_family = AF_UNSPEC, + .ai_socktype = SOCK_STREAM + }; + int rc, index = 1; + + const char *host = lua_tostring(L, -2); + const char *service = lua_tostring(L, -1); + + if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) { + const char *msg = gai_strerror(rc); + fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg); + exit(1); + } + + lua_newtable(L); + for (struct addrinfo *addr = addrs; addr != NULL; addr = addr->ai_next) { + struct addrinfo *udata = lua_newuserdata(L, sizeof(*udata)); + luaL_getmetatable(L, "wrk.addr"); + lua_setmetatable(L, -2); + + *udata = *addr; + udata->ai_addr = zmalloc(addr->ai_addrlen); + memcpy(udata->ai_addr, addr->ai_addr, addr->ai_addrlen); + lua_rawseti(L, -2, index++); + } + + freeaddrinfo(addrs); + return 1; +} + +static int script_wrk_connect(lua_State *L) { + struct addrinfo *addr = checkaddr(L); + int fd, connected = 0; + if ((fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol)) != -1) { + connected = connect(fd, addr->ai_addr, addr->ai_addrlen) == 0; + close(fd); + } + lua_pushboolean(L, connected); + return 1; +} + static void set_fields(lua_State *L, int index, const table_field *fields) { for (int i = 0; fields[i].name; i++) { table_field f = fields[i]; switch (f.value == NULL ? LUA_TNIL : f.type) { + case LUA_TFUNCTION: + lua_pushcfunction(L, (lua_CFunction) f.value); + break; case LUA_TNUMBER: lua_pushinteger(L, *((lua_Integer *) f.value)); break; diff --git a/src/script.h b/src/script.h index 72a3c32..3baccd9 100644 --- a/src/script.h +++ b/src/script.h @@ -5,6 +5,9 @@ #include #include #include +#include +#include +#include #include "stats.h" typedef struct { @@ -14,6 +17,9 @@ typedef struct { } buffer; lua_State *script_create(char *, char *, char *, char *); +void script_prepare_setup(lua_State *, char *); +bool script_resolve(lua_State *, char *, char *); +struct addrinfo *script_peek_addr(lua_State *); void script_headers(lua_State *, char **); size_t script_verify_request(lua_State *L); diff --git a/src/wrk.c b/src/wrk.c index 8eff169..b77a838 100644 --- a/src/wrk.c +++ b/src/wrk.c @@ -4,7 +4,6 @@ #include "main.h" static struct config { - struct addrinfo addr; uint64_t threads; uint64_t connections; uint64_t duration; @@ -58,10 +57,8 @@ static void usage() { } int main(int argc, char **argv) { - struct addrinfo *addrs, *addr; struct http_parser_url parser_url; char *url, **headers; - int rc; headers = zmalloc((argc / 2) * sizeof(char *)); @@ -85,26 +82,9 @@ int main(int argc, char **argv) { path = &url[parser_url.field_data[UF_PATH].off]; } - struct addrinfo hints = { - .ai_family = AF_UNSPEC, - .ai_socktype = SOCK_STREAM - }; - - if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) { - const char *msg = gai_strerror(rc); - fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg); - exit(1); - } - - for (addr = addrs; addr != NULL; addr = addr->ai_next) { - int fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); - if (fd == -1) continue; - rc = connect(fd, addr->ai_addr, addr->ai_addrlen); - close(fd); - if (rc == 0) break; - } - - if (addr == NULL) { + lua_State *L = script_create(schema, host, port, path); + script_prepare_setup(L, cfg.script); + if (!script_resolve(L, host, service)) { char *msg = strerror(errno); fprintf(stderr, "unable to connect to %s:%s %s\n", host, service, msg); exit(1); @@ -125,7 +105,6 @@ int main(int argc, char **argv) { signal(SIGPIPE, SIG_IGN); signal(SIGINT, SIG_IGN); - cfg.addr = *addr; pthread_mutex_init(&statistics.mutex, NULL); statistics.latency = stats_alloc(SAMPLES); @@ -138,6 +117,7 @@ int main(int argc, char **argv) { for (uint64_t i = 0; i < cfg.threads; i++) { thread *t = &threads[i]; t->loop = aeCreateEventLoop(10 + cfg.connections * 3); + t->addr = script_peek_addr(L); t->connections = connections; t->stop_at = stop_at; @@ -217,7 +197,6 @@ int main(int argc, char **argv) { printf("Requests/sec: %9.2Lf\n", req_per_s); printf("Transfer/sec: %10sB\n", format_binary(bytes_per_s)); - lua_State *L = threads[0].L; if (script_has_done(L)) { script_summary(L, runtime_us, complete, bytes); script_errors(L, &errors); @@ -274,16 +253,16 @@ void *thread_main(void *arg) { } static int connect_socket(thread *thread, connection *c) { - struct addrinfo addr = cfg.addr; + struct addrinfo *addr = thread->addr; struct aeEventLoop *loop = thread->loop; int fd, flags; - fd = socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol); + fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol); flags = fcntl(fd, F_GETFL, 0); fcntl(fd, F_SETFL, flags | O_NONBLOCK); - if (connect(fd, addr.ai_addr, addr.ai_addrlen) == -1) { + if (connect(fd, addr->ai_addr, addr->ai_addrlen) == -1) { if (errno != EINPROGRESS) goto error; } diff --git a/src/wrk.h b/src/wrk.h index 4ee683f..bf47350 100644 --- a/src/wrk.h +++ b/src/wrk.h @@ -5,6 +5,7 @@ #include #include #include +#include #include #include @@ -25,6 +26,7 @@ typedef struct { pthread_t thread; aeEventLoop *loop; + struct addrinfo *addr; uint64_t connections; int interval; uint64_t stop_at; diff --git a/src/wrk.lua b/src/wrk.lua index 17dc6d7..cb43cde 100644 --- a/src/wrk.lua +++ b/src/wrk.lua @@ -8,6 +8,37 @@ local wrk = { body = nil } +function wrk.resolve(host, service) + local addrs = wrk.lookup(host, service) + for i = #addrs, 1, -1 do + if not wrk.connect(addrs[i]) then + table.remove(addrs, i) + end + end + wrk.addrs = addrs +end + +function wrk.init(args) + if not wrk.headers["Host"] then + local host = wrk.host + local port = wrk.port + + host = host:find(":") and ("[" .. host .. "]") or host + host = port and (host .. ":" .. port) or host + + wrk.headers["Host"] = host + end + + if type(init) == "function" then + init(args) + end + + local req = wrk.format() + wrk.request = function() + return req + end +end + function wrk.format(method, path, headers, body) local method = method or wrk.method local path = path or wrk.path @@ -32,25 +63,4 @@ function wrk.format(method, path, headers, body) return table.concat(s, "\r\n") end -function wrk.init(args) - if not wrk.headers["Host"] then - local host = wrk.host - local port = wrk.port - - host = host:find(":") and ("[" .. host .. "]") or host - host = port and (host .. ":" .. port) or host - - wrk.headers["Host"] = host - end - - if type(init) == "function" then - init(args) - end - - local req = wrk.format() - wrk.request = function() - return req - end -end - return wrk