1
0
mirror of https://github.com/wg/wrk synced 2025-01-23 04:02:59 +08:00

move address resolution into lua

This commit is contained in:
Will 2015-02-07 14:08:30 +09:00
parent 93348e2814
commit 6f0aa32ede
6 changed files with 173 additions and 55 deletions

View File

@ -6,7 +6,6 @@
#include <fcntl.h>
#include <getopt.h>
#include <math.h>
#include <netdb.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <stdarg.h>

View File

@ -4,6 +4,7 @@
#include <string.h>
#include "script.h"
#include "http_parser.h"
#include "zmalloc.h"
typedef struct {
char *name;
@ -11,10 +12,20 @@ typedef struct {
void *value;
} table_field;
static int script_addr_tostring(lua_State *);
static int script_addr_gc(lua_State *);
static int script_stats_len(lua_State *);
static int script_stats_get(lua_State *);
static int script_wrk_lookup(lua_State *);
static int script_wrk_connect(lua_State *);
static void set_fields(lua_State *, int index, const table_field *);
static const struct luaL_reg addrlib[] = {
{ "__tostring", script_addr_tostring },
{ "__gc" , script_addr_gc },
{ NULL, NULL }
};
static const struct luaL_reg statslib[] = {
{ "__index", script_stats_get },
{ "__len", script_stats_len },
@ -26,16 +37,22 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) {
luaL_openlibs(L);
luaL_dostring(L, "wrk = require \"wrk\"");
luaL_newmetatable(L, "wrk.addr");
luaL_register(L, NULL, addrlib);
lua_pop(L, 1);
luaL_newmetatable(L, "wrk.stats");
luaL_register(L, NULL, statslib);
lua_pop(L, 1);
const table_field fields[] = {
{ "scheme", LUA_TSTRING, scheme },
{ "host", LUA_TSTRING, host },
{ "port", LUA_TSTRING, port },
{ "path", LUA_TSTRING, path },
{ NULL, 0, NULL },
{ "scheme", LUA_TSTRING, scheme },
{ "host", LUA_TSTRING, host },
{ "port", LUA_TSTRING, port },
{ "path", LUA_TSTRING, path },
{ "lookup", LUA_TFUNCTION, script_wrk_lookup },
{ "connect", LUA_TFUNCTION, script_wrk_connect },
{ NULL, 0, NULL },
};
lua_getglobal(L, "wrk");
@ -45,6 +62,36 @@ lua_State *script_create(char *scheme, char *host, char *port, char *path) {
return L;
}
void script_prepare_setup(lua_State *L, char *script) {
if (script && luaL_dofile(L, script)) {
const char *cause = lua_tostring(L, -1);
fprintf(stderr, "%s: %s\n", script, cause);
}
}
bool script_resolve(lua_State *L, char *host, char *service) {
lua_getglobal(L, "wrk");
lua_getfield(L, -1, "resolve");
lua_pushstring(L, host);
lua_pushstring(L, service);
lua_call(L, 2, 0);
lua_getfield(L, -1, "addrs");
size_t count = lua_objlen(L, -1);
lua_pop(L, 2);
return count > 0;
}
struct addrinfo *script_peek_addr(lua_State *L) {
lua_getglobal(L, "wrk");
lua_getfield(L, -1, "addrs");
lua_rawgeti(L, -1, 1);
struct addrinfo *addr = lua_touserdata(L, -1);
lua_pop(L, 3);
return addr;
}
void script_headers(lua_State *L, char **headers) {
lua_getglobal(L, "wrk");
lua_getfield(L, 1, "headers");
@ -225,6 +272,34 @@ size_t script_verify_request(lua_State *L) {
return count;
}
static struct addrinfo *checkaddr(lua_State *L) {
struct addrinfo *addr = luaL_checkudata(L, -1, "wrk.addr");
luaL_argcheck(L, addr != NULL, 1, "`addr' expected");
return addr;
}
static int script_addr_tostring(lua_State *L) {
struct addrinfo *addr = checkaddr(L);
char host[NI_MAXHOST];
char service[NI_MAXSERV];
int flags = NI_NUMERICHOST | NI_NUMERICSERV;
int rc = getnameinfo(addr->ai_addr, addr->ai_addrlen, host, NI_MAXHOST, service, NI_MAXSERV, flags);
if (rc != 0) {
const char *msg = gai_strerror(rc);
return luaL_error(L, "addr tostring failed %s", msg);
}
lua_pushfstring(L, "%s:%s", host, service);
return 1;
}
static int script_addr_gc(lua_State *L) {
struct addrinfo *addr = checkaddr(L);
zfree(addr->ai_addr);
return 0;
}
static stats *checkstats(lua_State *L) {
stats **s = luaL_checkudata(L, 1, "wrk.stats");
luaL_argcheck(L, s != NULL, 1, "`stats' expected");
@ -262,10 +337,57 @@ static int script_stats_len(lua_State *L) {
return 1;
}
static int script_wrk_lookup(lua_State *L) {
struct addrinfo *addrs;
struct addrinfo hints = {
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM
};
int rc, index = 1;
const char *host = lua_tostring(L, -2);
const char *service = lua_tostring(L, -1);
if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) {
const char *msg = gai_strerror(rc);
fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg);
exit(1);
}
lua_newtable(L);
for (struct addrinfo *addr = addrs; addr != NULL; addr = addr->ai_next) {
struct addrinfo *udata = lua_newuserdata(L, sizeof(*udata));
luaL_getmetatable(L, "wrk.addr");
lua_setmetatable(L, -2);
*udata = *addr;
udata->ai_addr = zmalloc(addr->ai_addrlen);
memcpy(udata->ai_addr, addr->ai_addr, addr->ai_addrlen);
lua_rawseti(L, -2, index++);
}
freeaddrinfo(addrs);
return 1;
}
static int script_wrk_connect(lua_State *L) {
struct addrinfo *addr = checkaddr(L);
int fd, connected = 0;
if ((fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol)) != -1) {
connected = connect(fd, addr->ai_addr, addr->ai_addrlen) == 0;
close(fd);
}
lua_pushboolean(L, connected);
return 1;
}
static void set_fields(lua_State *L, int index, const table_field *fields) {
for (int i = 0; fields[i].name; i++) {
table_field f = fields[i];
switch (f.value == NULL ? LUA_TNIL : f.type) {
case LUA_TFUNCTION:
lua_pushcfunction(L, (lua_CFunction) f.value);
break;
case LUA_TNUMBER:
lua_pushinteger(L, *((lua_Integer *) f.value));
break;

View File

@ -5,6 +5,9 @@
#include <lua.h>
#include <lualib.h>
#include <lauxlib.h>
#include <sys/types.h>
#include <netdb.h>
#include <unistd.h>
#include "stats.h"
typedef struct {
@ -14,6 +17,9 @@ typedef struct {
} buffer;
lua_State *script_create(char *, char *, char *, char *);
void script_prepare_setup(lua_State *, char *);
bool script_resolve(lua_State *, char *, char *);
struct addrinfo *script_peek_addr(lua_State *);
void script_headers(lua_State *, char **);
size_t script_verify_request(lua_State *L);

View File

@ -4,7 +4,6 @@
#include "main.h"
static struct config {
struct addrinfo addr;
uint64_t threads;
uint64_t connections;
uint64_t duration;
@ -58,10 +57,8 @@ static void usage() {
}
int main(int argc, char **argv) {
struct addrinfo *addrs, *addr;
struct http_parser_url parser_url;
char *url, **headers;
int rc;
headers = zmalloc((argc / 2) * sizeof(char *));
@ -85,26 +82,9 @@ int main(int argc, char **argv) {
path = &url[parser_url.field_data[UF_PATH].off];
}
struct addrinfo hints = {
.ai_family = AF_UNSPEC,
.ai_socktype = SOCK_STREAM
};
if ((rc = getaddrinfo(host, service, &hints, &addrs)) != 0) {
const char *msg = gai_strerror(rc);
fprintf(stderr, "unable to resolve %s:%s %s\n", host, service, msg);
exit(1);
}
for (addr = addrs; addr != NULL; addr = addr->ai_next) {
int fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
if (fd == -1) continue;
rc = connect(fd, addr->ai_addr, addr->ai_addrlen);
close(fd);
if (rc == 0) break;
}
if (addr == NULL) {
lua_State *L = script_create(schema, host, port, path);
script_prepare_setup(L, cfg.script);
if (!script_resolve(L, host, service)) {
char *msg = strerror(errno);
fprintf(stderr, "unable to connect to %s:%s %s\n", host, service, msg);
exit(1);
@ -125,7 +105,6 @@ int main(int argc, char **argv) {
signal(SIGPIPE, SIG_IGN);
signal(SIGINT, SIG_IGN);
cfg.addr = *addr;
pthread_mutex_init(&statistics.mutex, NULL);
statistics.latency = stats_alloc(SAMPLES);
@ -138,6 +117,7 @@ int main(int argc, char **argv) {
for (uint64_t i = 0; i < cfg.threads; i++) {
thread *t = &threads[i];
t->loop = aeCreateEventLoop(10 + cfg.connections * 3);
t->addr = script_peek_addr(L);
t->connections = connections;
t->stop_at = stop_at;
@ -217,7 +197,6 @@ int main(int argc, char **argv) {
printf("Requests/sec: %9.2Lf\n", req_per_s);
printf("Transfer/sec: %10sB\n", format_binary(bytes_per_s));
lua_State *L = threads[0].L;
if (script_has_done(L)) {
script_summary(L, runtime_us, complete, bytes);
script_errors(L, &errors);
@ -274,16 +253,16 @@ void *thread_main(void *arg) {
}
static int connect_socket(thread *thread, connection *c) {
struct addrinfo addr = cfg.addr;
struct addrinfo *addr = thread->addr;
struct aeEventLoop *loop = thread->loop;
int fd, flags;
fd = socket(addr.ai_family, addr.ai_socktype, addr.ai_protocol);
fd = socket(addr->ai_family, addr->ai_socktype, addr->ai_protocol);
flags = fcntl(fd, F_GETFL, 0);
fcntl(fd, F_SETFL, flags | O_NONBLOCK);
if (connect(fd, addr.ai_addr, addr.ai_addrlen) == -1) {
if (connect(fd, addr->ai_addr, addr->ai_addrlen) == -1) {
if (errno != EINPROGRESS) goto error;
}

View File

@ -5,6 +5,7 @@
#include <pthread.h>
#include <inttypes.h>
#include <sys/types.h>
#include <netdb.h>
#include <openssl/ssl.h>
#include <openssl/err.h>
@ -25,6 +26,7 @@
typedef struct {
pthread_t thread;
aeEventLoop *loop;
struct addrinfo *addr;
uint64_t connections;
int interval;
uint64_t stop_at;

View File

@ -8,6 +8,37 @@ local wrk = {
body = nil
}
function wrk.resolve(host, service)
local addrs = wrk.lookup(host, service)
for i = #addrs, 1, -1 do
if not wrk.connect(addrs[i]) then
table.remove(addrs, i)
end
end
wrk.addrs = addrs
end
function wrk.init(args)
if not wrk.headers["Host"] then
local host = wrk.host
local port = wrk.port
host = host:find(":") and ("[" .. host .. "]") or host
host = port and (host .. ":" .. port) or host
wrk.headers["Host"] = host
end
if type(init) == "function" then
init(args)
end
local req = wrk.format()
wrk.request = function()
return req
end
end
function wrk.format(method, path, headers, body)
local method = method or wrk.method
local path = path or wrk.path
@ -32,25 +63,4 @@ function wrk.format(method, path, headers, body)
return table.concat(s, "\r\n")
end
function wrk.init(args)
if not wrk.headers["Host"] then
local host = wrk.host
local port = wrk.port
host = host:find(":") and ("[" .. host .. "]") or host
host = port and (host .. ":" .. port) or host
wrk.headers["Host"] = host
end
if type(init) == "function" then
init(args)
end
local req = wrk.format()
wrk.request = function()
return req
end
end
return wrk