diff --git a/ext/socket/extconf.rb b/ext/socket/extconf.rb index 0cc8a88..42cfb06 100644 --- a/ext/socket/extconf.rb +++ b/ext/socket/extconf.rb @@ -444,6 +444,7 @@ def test_recvmsg_with_msg_peek_creates_fds(headers) test_func = "socket(0,0,0)" have_library("nsl", 't_open("", 0, (struct t_info *)NULL)', headers) # SunOS have_library("socket", "socket(0,0,0)", headers) # SunOS + have_library("anl", 'getaddrinfo_a', headers) end if have_func(test_func, headers) @@ -505,6 +506,7 @@ def test_recvmsg_with_msg_peek_creates_fds(headers) unless have_func("gethostname((char *)0, 0)", headers) have_func("uname((struct utsname *)NULL)", headers) end + have_func("getaddrinfo_a", headers) ipv6 = false default_ipv6 = /haiku/ !~ RUBY_PLATFORM diff --git a/ext/socket/lib/socket.rb b/ext/socket/lib/socket.rb index 4ed2df2..8ba3a5a 100644 --- a/ext/socket/lib/socket.rb +++ b/ext/socket/lib/socket.rb @@ -215,6 +215,22 @@ def listen(backlog=Socket::SOMAXCONN) end end + def self.getaddrinfo(nodename, servname, family=nil, socktype=nil, protocol=nil, flags=nil, timeout: nil) + if timeout + if defined?(_getaddrinfo_a) + _getaddrinfo_a(nodename, servname, family, socktype, protocol, flags, timeout: timeout) + else + require "timeout" + + Timeout.timeout(timeout, SocketError) do + _getaddrinfo(nodename, servname, family, socktype, protocol, flags) + end + end + else + _getaddrinfo(nodename, servname, family, socktype, protocol, flags) + end + end + # iterates over the list of Addrinfo objects obtained by Addrinfo.getaddrinfo. # # Addrinfo.foreach(nil, 80) {|x| p x } @@ -223,8 +239,8 @@ def listen(backlog=Socket::SOMAXCONN) # # # # # # # - def self.foreach(nodename, service, family=nil, socktype=nil, protocol=nil, flags=nil, &block) - Addrinfo.getaddrinfo(nodename, service, family, socktype, protocol, flags).each(&block) + def self.foreach(nodename, service, family=nil, socktype=nil, protocol=nil, flags=nil, timeout: nil, &block) + Addrinfo.getaddrinfo(nodename, service, family, socktype, protocol, flags, timeout: timeout).each(&block) end end @@ -606,6 +622,7 @@ def accept_nonblock(exception: true) # _opts_ may have following options: # # [:connect_timeout] specify the timeout in seconds. + # [:resolv_timeout] specify the name resolution timeout in seconds. # # If a block is given, the block is called with the socket. # The value of the block is returned. @@ -619,7 +636,7 @@ def accept_nonblock(exception: true) # puts sock.read # } # - def self.tcp(host, port, local_host = nil, local_port = nil, connect_timeout: nil) # :yield: socket + def self.tcp(host, port, local_host = nil, local_port = nil, connect_timeout: nil, resolv_timeout: nil) # :yield: socket last_error = nil ret = nil @@ -628,7 +645,7 @@ def self.tcp(host, port, local_host = nil, local_port = nil, connect_timeout: ni local_addr_list = Addrinfo.getaddrinfo(local_host, local_port, nil, :STREAM, nil) end - Addrinfo.foreach(host, port, nil, :STREAM) {|ai| + Addrinfo.foreach(host, port, nil, :STREAM, timeout: resolv_timeout) {|ai| if local_addr_list local_addr = local_addr_list.find {|local_ai| local_ai.afamily == ai.afamily } next unless local_addr diff --git a/ext/socket/raddrinfo.c b/ext/socket/raddrinfo.c index a6abad6..981a699 100644 --- a/ext/socket/raddrinfo.c +++ b/ext/socket/raddrinfo.c @@ -9,6 +9,7 @@ ************************************************/ #include "rubysocket.h" +#include "hrtime.h" #if defined(INET6) && (defined(LOOKUP_ORDER_HACK_INET) || defined(LOOKUP_ORDER_HACK_INET6)) #define LOOKUP_ORDERS (sizeof(lookup_order_table) / sizeof(lookup_order_table[0])) @@ -199,6 +200,27 @@ nogvl_getaddrinfo(void *arg) } #endif +#ifdef HAVE_GETADDRINFO_A +struct gai_suspend_arg +{ + struct gaicb *req; + struct timespec *timeout; +}; + +static void * +nogvl_gai_suspend(void *arg) +{ + int ret; + struct gai_suspend_arg *ptr = arg; + struct gaicb const *wait_reqs[1]; + + wait_reqs[0] = ptr->req; + ret = gai_suspend(wait_reqs, 1, ptr->timeout); + + return (void *)(VALUE)ret; +} +#endif + static int numeric_getaddrinfo(const char *node, const char *service, const struct addrinfo *hints, @@ -318,6 +340,59 @@ rb_getaddrinfo(const char *node, const char *service, return ret; } +#ifdef HAVE_GETADDRINFO_A +int +rb_getaddrinfo_a(const char *node, const char *service, + const struct addrinfo *hints, + struct rb_addrinfo **res, struct timespec *timeout) +{ + struct addrinfo *ai; + int ret; + int allocated_by_malloc = 0; + + ret = numeric_getaddrinfo(node, service, hints, &ai); + if (ret == 0) + allocated_by_malloc = 1; + else { + struct gai_suspend_arg arg; + struct gaicb *reqs[1]; + struct gaicb req; + + req.ar_name = node; + req.ar_service = service; + req.ar_request = hints; + + reqs[0] = &req; + ret = getaddrinfo_a(GAI_NOWAIT, reqs, 1, NULL); + if (ret) return ret; + + arg.req = &req; + arg.timeout = timeout; + + ret = (int)(VALUE)rb_thread_call_without_gvl(nogvl_gai_suspend, &arg, RUBY_UBF_IO, 0); + + if (ret) { + /* on Ubuntu 18.04 (or other systems), gai_suspend(3) returns EAI_SYSTEM/ENOENT on timeout */ + if (ret == EAI_SYSTEM && errno == ENOENT) { + return EAI_AGAIN; + } else { + return ret; + } + } + + ret = gai_error(reqs[0]); + ai = reqs[0]->ar_result; + } + + if (ret == 0) { + *res = (struct rb_addrinfo *)xmalloc(sizeof(struct rb_addrinfo)); + (*res)->allocated_by_malloc = allocated_by_malloc; + (*res)->ai = ai; + } + return ret; +} +#endif + void rb_freeaddrinfo(struct rb_addrinfo *ai) { @@ -530,6 +605,42 @@ rsock_getaddrinfo(VALUE host, VALUE port, struct addrinfo *hints, int socktype_h return res; } +#ifdef HAVE_GETADDRINFO_A +static struct rb_addrinfo* +rsock_getaddrinfo_a(VALUE host, VALUE port, struct addrinfo *hints, int socktype_hack, VALUE timeout) +{ + struct rb_addrinfo* res = NULL; + char *hostp, *portp; + int error; + char hbuf[NI_MAXHOST], pbuf[NI_MAXSERV]; + int additional_flags = 0; + + hostp = host_str(host, hbuf, sizeof(hbuf), &additional_flags); + portp = port_str(port, pbuf, sizeof(pbuf), &additional_flags); + + if (socktype_hack && hints->ai_socktype == 0 && str_is_number(portp)) { + hints->ai_socktype = SOCK_DGRAM; + } + hints->ai_flags |= additional_flags; + + if (NIL_P(timeout)) { + error = rb_getaddrinfo(hostp, portp, hints, &res); + } else { + struct timespec _timeout = rb_time_timespec_interval(timeout); + error = rb_getaddrinfo_a(hostp, portp, hints, &res, &_timeout); + } + + if (error) { + if (hostp && hostp[strlen(hostp)-1] == '\n') { + rb_raise(rb_eSocket, "newline at the end of hostname"); + } + rsock_raise_socket_error("getaddrinfo_a", error); + } + + return res; +} +#endif + int rsock_fd_family(int fd) { @@ -811,7 +922,7 @@ rsock_addrinfo_new(struct sockaddr *addr, socklen_t len, static struct rb_addrinfo * call_getaddrinfo(VALUE node, VALUE service, VALUE family, VALUE socktype, VALUE protocol, VALUE flags, - int socktype_hack) + int socktype_hack, VALUE timeout) { struct addrinfo hints; struct rb_addrinfo *res; @@ -828,7 +939,16 @@ call_getaddrinfo(VALUE node, VALUE service, if (!NIL_P(flags)) { hints.ai_flags = NUM2INT(flags); } - res = rsock_getaddrinfo(node, service, &hints, socktype_hack); + + if (NIL_P(timeout)) { + res = rsock_getaddrinfo(node, service, &hints, socktype_hack); + } else { +#ifdef HAVE_GETADDRINFO_A + res = rsock_getaddrinfo_a(node, service, &hints, socktype_hack, timeout); +#else + rb_f_notimplement(); +#endif + } if (res == NULL) rb_raise(rb_eSocket, "host not found"); @@ -842,7 +962,7 @@ init_addrinfo_getaddrinfo(rb_addrinfo_t *rai, VALUE node, VALUE service, VALUE family, VALUE socktype, VALUE protocol, VALUE flags, VALUE inspectnode, VALUE inspectservice) { - struct rb_addrinfo *res = call_getaddrinfo(node, service, family, socktype, protocol, flags, 1); + struct rb_addrinfo *res = call_getaddrinfo(node, service, family, socktype, protocol, flags, 1, Qnil); VALUE canonname; VALUE inspectname = rb_str_equal(node, inspectnode) ? Qnil : make_inspectname(inspectnode, inspectservice, res->ai); @@ -912,7 +1032,7 @@ addrinfo_firstonly_new(VALUE node, VALUE service, VALUE family, VALUE socktype, VALUE canonname; VALUE inspectname; - struct rb_addrinfo *res = call_getaddrinfo(node, service, family, socktype, protocol, flags, 0); + struct rb_addrinfo *res = call_getaddrinfo(node, service, family, socktype, protocol, flags, 0, Qnil); inspectname = make_inspectname(node, service, res->ai); @@ -932,13 +1052,13 @@ addrinfo_firstonly_new(VALUE node, VALUE service, VALUE family, VALUE socktype, } static VALUE -addrinfo_list_new(VALUE node, VALUE service, VALUE family, VALUE socktype, VALUE protocol, VALUE flags) +addrinfo_list_new(VALUE node, VALUE service, VALUE family, VALUE socktype, VALUE protocol, VALUE flags, VALUE timeout) { VALUE ret; struct addrinfo *r; VALUE inspectname; - struct rb_addrinfo *res = call_getaddrinfo(node, service, family, socktype, protocol, flags, 0); + struct rb_addrinfo *res = call_getaddrinfo(node, service, family, socktype, protocol, flags, 0, timeout); inspectname = make_inspectname(node, service, res->ai); @@ -1691,7 +1811,7 @@ addrinfo_mload(VALUE self, VALUE ary) #endif res = call_getaddrinfo(rb_ary_entry(pair, 0), rb_ary_entry(pair, 1), INT2NUM(pfamily), INT2NUM(socktype), INT2NUM(protocol), - INT2NUM(flags), 1); + INT2NUM(flags), 1, Qnil); len = res->ai->ai_addrlen; memcpy(&ss, res->ai->ai_addr, res->ai->ai_addrlen); @@ -2373,13 +2493,28 @@ addrinfo_unix_path(VALUE self) * */ static VALUE -addrinfo_s_getaddrinfo(int argc, VALUE *argv, VALUE self) +addrinfo_private_getaddrinfo(int argc, VALUE *argv, VALUE self) { VALUE node, service, family, socktype, protocol, flags; rb_scan_args(argc, argv, "24", &node, &service, &family, &socktype, &protocol, &flags); - return addrinfo_list_new(node, service, family, socktype, protocol, flags); + return addrinfo_list_new(node, service, family, socktype, protocol, flags, Qnil); +} + +#ifdef HAVE_GETADDRINFO_A +static ID id_timeout; + +static VALUE +addrinfo_private_getaddrinfo_a(int argc, VALUE *argv, VALUE self) +{ + VALUE node, service, family, socktype, protocol, flags, opts, timeout; + + rb_scan_args(argc, argv, "24:", &node, &service, &family, &socktype, + &protocol, &flags, &opts); + rb_get_kwargs(opts, &id_timeout, 1, 0, &timeout); + return addrinfo_list_new(node, service, family, socktype, protocol, flags, timeout); } +#endif /* * call-seq: @@ -2568,7 +2703,14 @@ rsock_init_addrinfo(void) rb_define_method(rb_cAddrinfo, "initialize", addrinfo_initialize, -1); rb_define_method(rb_cAddrinfo, "inspect", addrinfo_inspect, 0); rb_define_method(rb_cAddrinfo, "inspect_sockaddr", rsock_addrinfo_inspect_sockaddr, 0); - rb_define_singleton_method(rb_cAddrinfo, "getaddrinfo", addrinfo_s_getaddrinfo, -1); + + rb_define_private_method(CLASS_OF(rb_cAddrinfo), "_getaddrinfo", + addrinfo_private_getaddrinfo, -1); +#ifdef HAVE_GETADDRINFO_A + id_timeout = rb_intern("timeout"); + rb_define_private_method(CLASS_OF(rb_cAddrinfo), "_getaddrinfo_a", + addrinfo_private_getaddrinfo_a, -1); +#endif rb_define_singleton_method(rb_cAddrinfo, "ip", addrinfo_s_ip, 1); rb_define_singleton_method(rb_cAddrinfo, "tcp", addrinfo_s_tcp, 2); rb_define_singleton_method(rb_cAddrinfo, "udp", addrinfo_s_udp, 2); diff --git a/include/ruby/intern.h b/include/ruby/intern.h index 17aafd7..c05af10 100644 --- a/include/ruby/intern.h +++ b/include/ruby/intern.h @@ -933,6 +933,7 @@ VALUE rb_time_num_new(VALUE, VALUE); struct timeval rb_time_interval(VALUE num); struct timeval rb_time_timeval(VALUE time); struct timespec rb_time_timespec(VALUE time); +struct timespec rb_time_timespec_interval(VALUE num); VALUE rb_time_utc_offset(VALUE time); /* variable.c */ VALUE rb_mod_name(VALUE); diff --git a/test/socket/test_addrinfo.rb b/test/socket/test_addrinfo.rb index a06f3eb..2af6918 100644 --- a/test/socket/test_addrinfo.rb +++ b/test/socket/test_addrinfo.rb @@ -688,5 +688,10 @@ def test_marshal_unix assert_equal(ai1.canonname, ai2.canonname) end + def test_addrinfo_timeout + ai = Addrinfo.getaddrinfo("localhost.localdomain", "http", Socket::PF_INET, Socket::SOCK_STREAM, timeout: 1).fetch(0) + assert_equal(6, ai.protocol) + assert_equal(80, ai.ip_port) + end end end diff --git a/time.c b/time.c index 97d6d1d..7cd5445 100644 --- a/time.c +++ b/time.c @@ -2671,6 +2671,12 @@ rb_time_timespec(VALUE time) return time_timespec(time, FALSE); } +struct timespec +rb_time_timespec_interval(VALUE num) +{ + return time_timespec(num, TRUE); +} + /* * call-seq: * Time.now -> time