Commit d2b1d6f2 authored by Andrey Vagin's avatar Andrey Vagin Committed by Pavel Emelyanov

socket: prevent dumping of sockets if they are not collected

The idea is simple. If the collection of given type of sockets failed,
crtools can't be sure, that it's able to dump such sockets correctly.
Signed-off-by: 's avatarAndrey Vagin <avagin@openvz.org>
Signed-off-by: 's avatarPavel Emelyanov <xemul@parallels.com>
parent d274025e
...@@ -32,6 +32,8 @@ extern int restore_socket_opts(int sk, SkOptsEntry *soe); ...@@ -32,6 +32,8 @@ extern int restore_socket_opts(int sk, SkOptsEntry *soe);
extern void release_skopts(SkOptsEntry *); extern void release_skopts(SkOptsEntry *);
extern int restore_prepare_socket(int sk); extern int restore_prepare_socket(int sk);
extern bool socket_test_collect_bit(unsigned int family, unsigned int proto);
extern int sk_collect_one(int ino, int family, struct socket_desc *d); extern int sk_collect_one(int ino, int family, struct socket_desc *d);
extern int collect_sockets(int pid); extern int collect_sockets(int pid);
extern int collect_inet_sockets(void); extern int collect_inet_sockets(void);
...@@ -51,7 +53,7 @@ extern char *sktype2s(u32 t); ...@@ -51,7 +53,7 @@ extern char *sktype2s(u32 t);
extern char *skproto2s(u32 p); extern char *skproto2s(u32 p);
extern char *skstate2s(u32 state); extern char *skstate2s(u32 state);
extern struct socket_desc *lookup_socket(int ino, int family); extern struct socket_desc *lookup_socket(int ino, int family, int proto);
extern int dump_one_inet(struct fd_parms *p, int lfd, const int fdinfo); extern int dump_one_inet(struct fd_parms *p, int lfd, const int fdinfo);
extern int dump_one_inet6(struct fd_parms *p, int lfd, const int fdinfo); extern int dump_one_inet6(struct fd_parms *p, int lfd, const int fdinfo);
......
...@@ -160,7 +160,7 @@ static int can_dump_inet_sk(const struct inet_sk_desc *sk) ...@@ -160,7 +160,7 @@ static int can_dump_inet_sk(const struct inet_sk_desc *sk)
return 1; return 1;
} }
static struct inet_sk_desc *gen_uncon_sk(int lfd, const struct fd_parms *p) static struct inet_sk_desc *gen_uncon_sk(int lfd, const struct fd_parms *p, int proto)
{ {
struct inet_sk_desc *sk; struct inet_sk_desc *sk;
char address; char address;
...@@ -188,10 +188,11 @@ static struct inet_sk_desc *gen_uncon_sk(int lfd, const struct fd_parms *p) ...@@ -188,10 +188,11 @@ static struct inet_sk_desc *gen_uncon_sk(int lfd, const struct fd_parms *p)
ret = do_dump_opt(lfd, SOL_SOCKET, SO_DOMAIN, &sk->sd.family, sizeof(sk->sd.family)); ret = do_dump_opt(lfd, SOL_SOCKET, SO_DOMAIN, &sk->sd.family, sizeof(sk->sd.family));
ret |= do_dump_opt(lfd, SOL_SOCKET, SO_TYPE, &sk->type, sizeof(sk->type)); ret |= do_dump_opt(lfd, SOL_SOCKET, SO_TYPE, &sk->type, sizeof(sk->type));
ret |= do_dump_opt(lfd, SOL_SOCKET, SO_PROTOCOL, &sk->proto, sizeof(sk->proto));
if (ret) if (ret)
goto err; goto err;
sk->proto = proto;
if (sk->proto == IPPROTO_TCP) { if (sk->proto == IPPROTO_TCP) {
struct tcp_info info; struct tcp_info info;
...@@ -226,11 +227,18 @@ static int do_dump_one_inet_fd(int lfd, u32 id, const struct fd_parms *p, int fa ...@@ -226,11 +227,18 @@ static int do_dump_one_inet_fd(int lfd, u32 id, const struct fd_parms *p, int fa
struct inet_sk_desc *sk; struct inet_sk_desc *sk;
InetSkEntry ie = INET_SK_ENTRY__INIT; InetSkEntry ie = INET_SK_ENTRY__INIT;
SkOptsEntry skopts = SK_OPTS_ENTRY__INIT; SkOptsEntry skopts = SK_OPTS_ENTRY__INIT;
int ret = -1, err = -1; int ret = -1, err = -1, proto;
sk = (struct inet_sk_desc *)lookup_socket(p->stat.st_ino, family); ret = do_dump_opt(lfd, SOL_SOCKET, SO_PROTOCOL,
&proto, sizeof(proto));
if (ret)
goto err;
sk = (struct inet_sk_desc *)lookup_socket(p->stat.st_ino, family, proto);
if (IS_ERR(sk))
goto err;
if (!sk) { if (!sk) {
sk = gen_uncon_sk(lfd, p); sk = gen_uncon_sk(lfd, p, proto);
if (!sk) if (!sk)
goto err; goto err;
} }
......
...@@ -90,7 +90,9 @@ static int dump_one_netlink_fd(int lfd, u32 id, const struct fd_parms *p) ...@@ -90,7 +90,9 @@ static int dump_one_netlink_fd(int lfd, u32 id, const struct fd_parms *p)
NetlinkSkEntry ne = NETLINK_SK_ENTRY__INIT; NetlinkSkEntry ne = NETLINK_SK_ENTRY__INIT;
SkOptsEntry skopts = SK_OPTS_ENTRY__INIT; SkOptsEntry skopts = SK_OPTS_ENTRY__INIT;
sk = (struct netlink_sk_desc *)lookup_socket(p->stat.st_ino, PF_NETLINK); sk = (struct netlink_sk_desc *)lookup_socket(p->stat.st_ino, PF_NETLINK, 0);
if (IS_ERR_OR_NULL(sk))
goto err;
ne.id = id; ne.id = id;
ne.ino = p->stat.st_ino; ne.ino = p->stat.st_ino;
......
...@@ -151,8 +151,8 @@ static int dump_one_packet_fd(int lfd, u32 id, const struct fd_parms *p) ...@@ -151,8 +151,8 @@ static int dump_one_packet_fd(int lfd, u32 id, const struct fd_parms *p)
struct packet_sock_desc *sd; struct packet_sock_desc *sd;
int i, ret; int i, ret;
sd = (struct packet_sock_desc *)lookup_socket(p->stat.st_ino, PF_PACKET); sd = (struct packet_sock_desc *)lookup_socket(p->stat.st_ino, PF_PACKET, 0);
if (sd == NULL) { if (IS_ERR_OR_NULL(sd)) {
pr_err("Can't find packet socket %lu\n", p->stat.st_ino); pr_err("Can't find packet socket %lu\n", p->stat.st_ino);
return -1; return -1;
} }
...@@ -219,8 +219,8 @@ int dump_socket_map(struct vma_area *vma) ...@@ -219,8 +219,8 @@ int dump_socket_map(struct vma_area *vma)
{ {
struct packet_sock_desc *sd; struct packet_sock_desc *sd;
sd = (struct packet_sock_desc *)lookup_socket(vma->vm_socket_id, PF_PACKET); sd = (struct packet_sock_desc *)lookup_socket(vma->vm_socket_id, PF_PACKET, 0);
if (!sd) { if (IS_ERR_OR_NULL(sd)) {
pr_err("Can't find packet socket %u to mmap\n", vma->vm_socket_id); pr_err("Can't find packet socket %u to mmap\n", vma->vm_socket_id);
return -1; return -1;
} }
......
...@@ -115,8 +115,8 @@ static int dump_one_unix_fd(int lfd, u32 id, const struct fd_parms *p) ...@@ -115,8 +115,8 @@ static int dump_one_unix_fd(int lfd, u32 id, const struct fd_parms *p)
SkOptsEntry skopts = SK_OPTS_ENTRY__INIT; SkOptsEntry skopts = SK_OPTS_ENTRY__INIT;
FilePermsEntry perms = FILE_PERMS_ENTRY__INIT; FilePermsEntry perms = FILE_PERMS_ENTRY__INIT;
sk = (struct unix_sk_desc *)lookup_socket(p->stat.st_ino, PF_UNIX); sk = (struct unix_sk_desc *)lookup_socket(p->stat.st_ino, PF_UNIX, 0);
if (!sk) if (IS_ERR_OR_NULL(sk))
goto err; goto err;
if (!can_dump_unix_sk(sk)) if (!can_dump_unix_sk(sk))
...@@ -151,8 +151,8 @@ static int dump_one_unix_fd(int lfd, u32 id, const struct fd_parms *p) ...@@ -151,8 +151,8 @@ static int dump_one_unix_fd(int lfd, u32 id, const struct fd_parms *p)
if (ue.peer) { if (ue.peer) {
struct unix_sk_desc *peer; struct unix_sk_desc *peer;
peer = (struct unix_sk_desc *)lookup_socket(ue.peer, PF_UNIX); peer = (struct unix_sk_desc *)lookup_socket(ue.peer, PF_UNIX, 0);
if (!peer) { if (IS_ERR_OR_NULL(peer)) {
pr_err("Unix socket %#x without peer %#x\n", pr_err("Unix socket %#x without peer %#x\n",
ue.ino, ue.peer); ue.ino, ue.peer);
goto err; goto err;
......
...@@ -39,6 +39,71 @@ ...@@ -39,6 +39,71 @@
#define SO_GET_FILTER SO_ATTACH_FILTER #define SO_GET_FILTER SO_ATTACH_FILTER
#endif #endif
enum socket_cl_bits
{
NETLINK_CL_BIT,
INET_TCP_CL_BIT,
INET_UDP_CL_BIT,
INET_UDPLITE_CL_BIT,
INET6_TCP_CL_BIT,
INET6_UDP_CL_BIT,
INET6_UDPLITE_CL_BIT,
UNIX_CL_BIT,
PACKET_CL_BIT,
_MAX_CL_BIT,
};
#define MAX_CL_BIT (_MAX_CL_BIT - 1)
static DECLARE_BITMAP(socket_cl_bits, MAX_CL_BIT);
static inline
enum socket_cl_bits get_collect_bit_nr(unsigned int family, unsigned int proto)
{
if (family == AF_NETLINK)
return NETLINK_CL_BIT;
if (family == AF_UNIX)
return UNIX_CL_BIT;
if (family == AF_PACKET)
return PACKET_CL_BIT;
if (family == AF_INET) {
if (proto == IPPROTO_TCP)
return INET_TCP_CL_BIT;
if (proto == IPPROTO_UDP)
return INET_UDP_CL_BIT;
if (proto == IPPROTO_UDPLITE)
return INET_UDPLITE_CL_BIT;
}
if (family == AF_INET6) {
if (proto == IPPROTO_TCP)
return INET6_TCP_CL_BIT;
if (proto == IPPROTO_UDP)
return INET6_UDP_CL_BIT;
if (proto == IPPROTO_UDPLITE)
return INET6_UDPLITE_CL_BIT;
}
pr_err("Unknown pair family %d proto %d\n", family, proto);
BUG();
return -1;
}
static void set_collect_bit(unsigned int family, unsigned int proto)
{
enum socket_cl_bits nr;
nr = get_collect_bit_nr(family, proto);
set_bit(nr, socket_cl_bits);
}
bool socket_test_collect_bit(unsigned int family, unsigned int proto)
{
enum socket_cl_bits nr;
nr = get_collect_bit_nr(family, proto);
return test_bit(nr, socket_cl_bits) != 0;
}
static int dump_bound_dev(int sk, SkOptsEntry *soe) static int dump_bound_dev(int sk, SkOptsEntry *soe)
{ {
int ret; int ret;
...@@ -162,10 +227,16 @@ static int restore_socket_filter(int sk, SkOptsEntry *soe) ...@@ -162,10 +227,16 @@ static int restore_socket_filter(int sk, SkOptsEntry *soe)
static struct socket_desc *sockets[SK_HASH_SIZE]; static struct socket_desc *sockets[SK_HASH_SIZE];
struct socket_desc *lookup_socket(int ino, int family) struct socket_desc *lookup_socket(int ino, int family, int proto)
{ {
struct socket_desc *sd; struct socket_desc *sd;
if (!socket_test_collect_bit(family, proto)) {
pr_err("Sockets (family %d, proto %d) are not collected\n",
family, proto);
return ERR_PTR(-EINVAL);
}
pr_debug("\tSearching for socket %x (family %d)\n", ino, family); pr_debug("\tSearching for socket %x (family %d)\n", ino, family);
for (sd = sockets[ino % SK_HASH_SIZE]; sd; sd = sd->next) for (sd = sockets[ino % SK_HASH_SIZE]; sd; sd = sd->next)
if (sd->ino == ino) { if (sd->ino == ino) {
...@@ -409,20 +480,35 @@ static int inet_receive_one(struct nlmsghdr *h, void *arg) ...@@ -409,20 +480,35 @@ static int inet_receive_one(struct nlmsghdr *h, void *arg)
return inet_collect_one(h, i->sdiag_family, type, i->sdiag_protocol); return inet_collect_one(h, i->sdiag_family, type, i->sdiag_protocol);
} }
struct sock_diag_req {
struct nlmsghdr hdr;
union {
struct unix_diag_req u;
struct inet_diag_req_v2 i;
struct packet_diag_req p;
struct netlink_diag_req n;
} r;
};
static int do_collect_req(int nl, struct sock_diag_req *req, int size,
int (*receive_callback)(struct nlmsghdr *h, void *), void *arg)
{
int tmp;
tmp = do_rtnl_req(nl, req, size, receive_callback, arg);
if (tmp == 0)
set_collect_bit(req->r.n.sdiag_family, req->r.n.sdiag_protocol);
return tmp;
}
int collect_sockets(int pid) int collect_sockets(int pid)
{ {
int err = 0, tmp; int err = 0, tmp;
int rst = -1; int rst = -1;
int nl; int nl;
struct { struct sock_diag_req req;
struct nlmsghdr hdr;
union {
struct unix_diag_req u;
struct inet_diag_req_v2 i;
struct packet_diag_req p;
struct netlink_diag_req n;
} r;
} req;
if (current_ns_mask & CLONE_NEWNET) { if (current_ns_mask & CLONE_NEWNET) {
pr_info("Switching to %d's net for collecting sockets\n", pid); pr_info("Switching to %d's net for collecting sockets\n", pid);
...@@ -450,7 +536,7 @@ int collect_sockets(int pid) ...@@ -450,7 +536,7 @@ int collect_sockets(int pid)
req.r.u.udiag_show = UDIAG_SHOW_NAME | UDIAG_SHOW_VFS | req.r.u.udiag_show = UDIAG_SHOW_NAME | UDIAG_SHOW_VFS |
UDIAG_SHOW_PEER | UDIAG_SHOW_ICONS | UDIAG_SHOW_PEER | UDIAG_SHOW_ICONS |
UDIAG_SHOW_RQLEN; UDIAG_SHOW_RQLEN;
tmp = do_rtnl_req(nl, &req, sizeof(req), unix_receive_one, NULL); tmp = do_collect_req(nl, &req, sizeof(req), unix_receive_one, NULL);
if (tmp) if (tmp)
err = tmp; err = tmp;
...@@ -460,7 +546,7 @@ int collect_sockets(int pid) ...@@ -460,7 +546,7 @@ int collect_sockets(int pid)
req.r.i.idiag_ext = 0; req.r.i.idiag_ext = 0;
/* Only listening and established sockets supported yet */ /* Only listening and established sockets supported yet */
req.r.i.idiag_states = (1 << TCP_LISTEN) | (1 << TCP_ESTABLISHED); req.r.i.idiag_states = (1 << TCP_LISTEN) | (1 << TCP_ESTABLISHED);
tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i); tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
if (tmp) if (tmp)
err = tmp; err = tmp;
...@@ -469,7 +555,7 @@ int collect_sockets(int pid) ...@@ -469,7 +555,7 @@ int collect_sockets(int pid)
req.r.i.sdiag_protocol = IPPROTO_UDP; req.r.i.sdiag_protocol = IPPROTO_UDP;
req.r.i.idiag_ext = 0; req.r.i.idiag_ext = 0;
req.r.i.idiag_states = -1; /* All */ req.r.i.idiag_states = -1; /* All */
tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i); tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
if (tmp) if (tmp)
err = tmp; err = tmp;
...@@ -478,7 +564,7 @@ int collect_sockets(int pid) ...@@ -478,7 +564,7 @@ int collect_sockets(int pid)
req.r.i.sdiag_protocol = IPPROTO_UDPLITE; req.r.i.sdiag_protocol = IPPROTO_UDPLITE;
req.r.i.idiag_ext = 0; req.r.i.idiag_ext = 0;
req.r.i.idiag_states = -1; /* All */ req.r.i.idiag_states = -1; /* All */
tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i); tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
if (tmp) if (tmp)
err = tmp; err = tmp;
...@@ -488,7 +574,7 @@ int collect_sockets(int pid) ...@@ -488,7 +574,7 @@ int collect_sockets(int pid)
req.r.i.idiag_ext = 0; req.r.i.idiag_ext = 0;
/* Only listening sockets supported yet */ /* Only listening sockets supported yet */
req.r.i.idiag_states = (1 << TCP_LISTEN) | (1 << TCP_ESTABLISHED); req.r.i.idiag_states = (1 << TCP_LISTEN) | (1 << TCP_ESTABLISHED);
tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i); tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
if (tmp) if (tmp)
err = tmp; err = tmp;
...@@ -497,7 +583,7 @@ int collect_sockets(int pid) ...@@ -497,7 +583,7 @@ int collect_sockets(int pid)
req.r.i.sdiag_protocol = IPPROTO_UDP; req.r.i.sdiag_protocol = IPPROTO_UDP;
req.r.i.idiag_ext = 0; req.r.i.idiag_ext = 0;
req.r.i.idiag_states = -1; /* All */ req.r.i.idiag_states = -1; /* All */
tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i); tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
if (tmp) if (tmp)
err = tmp; err = tmp;
...@@ -506,7 +592,7 @@ int collect_sockets(int pid) ...@@ -506,7 +592,7 @@ int collect_sockets(int pid)
req.r.i.sdiag_protocol = IPPROTO_UDPLITE; req.r.i.sdiag_protocol = IPPROTO_UDPLITE;
req.r.i.idiag_ext = 0; req.r.i.idiag_ext = 0;
req.r.i.idiag_states = -1; /* All */ req.r.i.idiag_states = -1; /* All */
tmp = do_rtnl_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i); tmp = do_collect_req(nl, &req, sizeof(req), inet_receive_one, &req.r.i);
if (tmp) if (tmp)
err = tmp; err = tmp;
...@@ -514,7 +600,7 @@ int collect_sockets(int pid) ...@@ -514,7 +600,7 @@ int collect_sockets(int pid)
req.r.p.sdiag_protocol = 0; req.r.p.sdiag_protocol = 0;
req.r.p.pdiag_show = PACKET_SHOW_INFO | PACKET_SHOW_MCLIST | req.r.p.pdiag_show = PACKET_SHOW_INFO | PACKET_SHOW_MCLIST |
PACKET_SHOW_FANOUT | PACKET_SHOW_RING_CFG; PACKET_SHOW_FANOUT | PACKET_SHOW_RING_CFG;
tmp = do_rtnl_req(nl, &req, sizeof(req), packet_receive_one, NULL); tmp = do_collect_req(nl, &req, sizeof(req), packet_receive_one, NULL);
if (tmp) { if (tmp) {
if (tmp == -ENOENT) /* Fedora 19 */ if (tmp == -ENOENT) /* Fedora 19 */
pr_warn("The currect kernel doesn't support packet_diag\n"); pr_warn("The currect kernel doesn't support packet_diag\n");
...@@ -525,7 +611,7 @@ int collect_sockets(int pid) ...@@ -525,7 +611,7 @@ int collect_sockets(int pid)
req.r.n.sdiag_family = AF_NETLINK; req.r.n.sdiag_family = AF_NETLINK;
req.r.n.sdiag_protocol = NDIAG_PROTO_ALL; req.r.n.sdiag_protocol = NDIAG_PROTO_ALL;
req.r.n.ndiag_show = NDIAG_SHOW_GROUPS; req.r.n.ndiag_show = NDIAG_SHOW_GROUPS;
tmp = do_rtnl_req(nl, &req, sizeof(req), netlink_receive_one, NULL); tmp = do_collect_req(nl, &req, sizeof(req), netlink_receive_one, NULL);
if (tmp) { if (tmp) {
if (tmp == -ENOENT) /* Going to be in 3.10 */ if (tmp == -ENOENT) /* Going to be in 3.10 */
pr_warn("The currect kernel doesn't support netlink_diag\n"); pr_warn("The currect kernel doesn't support netlink_diag\n");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment