1 /*
2 * Check decoding of SCM_PIDFD control messages.
3 *
4 * Copyright (c) 2023 Dmitry V. Levin <ldv@strace.io>
5 * All rights reserved.
6 *
7 * SPDX-License-Identifier: GPL-2.0-or-later
8 */
9
10 #include "tests.h"
11 #include <assert.h>
12 #include <stdio.h>
13 #include <string.h>
14 #include <unistd.h>
15 #include <sys/socket.h>
16
17 #define XLAT_MACROS_ONLY
18 # include "xlat/sock_options.h"
19 # include "xlat/scmvals.h"
20 #undef XLAT_MACROS_ONLY
21
/*
 * Print the payload of an SCM_PIDFD control message in the form this test
 * expects in strace output: the received descriptor number annotated with
 * "<anon_inode:[pidfd]>".
 *
 * The payload must be exactly one int; any other length fails the test.
 */
static void
print_pidfd(const struct cmsghdr *c)
{
	/* Byte pointers: arithmetic on void * is a GNU extension. */
	const unsigned char *cmsg_header = (const unsigned char *) c;
	const unsigned char *cmsg_data = CMSG_DATA(c);
	const size_t hdr_len = (size_t) (cmsg_data - cmsg_header);
	int pidfd;
	const unsigned int expected_len = sizeof(pidfd);

	/*
	 * These are consistency checks, not syscall failures, so errno is
	 * meaningless: use error_msg_and_fail, which appends its own newline.
	 * Check the short case first so the subtraction cannot wrap around.
	 */
	if (c->cmsg_len < hdr_len)
		error_msg_and_fail("cmsg_len = %lu is shorter than header"
				   " length %lu",
				   (unsigned long) c->cmsg_len,
				   (unsigned long) hdr_len);

	const unsigned int data_len = c->cmsg_len - hdr_len;

	if (expected_len != data_len)
		error_msg_and_fail("sizeof(pidfd) = %u, data_len = %u",
				   expected_len, data_len);

	/* memcpy: CMSG_DATA is not guaranteed to be suitably aligned for int. */
	memcpy(&pidfd, cmsg_data, sizeof(pidfd));
	printf("%d<anon_inode:[pidfd]>", pidfd);
}
38
39 int
40 main(void)
41 {
42 skip_if_unavailable("/proc/self/fd/");
43
44 int sv[2];
45 if (socketpair(AF_UNIX, SOCK_STREAM, 0, sv))
46 perror_msg_and_skip("socketpair AF_UNIX SOCK_STREAM");
47
48 int one = 1;
49 if (setsockopt(sv[0], SOL_SOCKET, SO_PASSPIDFD, &one, sizeof(one)))
50 perror_msg_and_skip("setsockopt SO_PASSPIDFD");
51
52 char sym = 'A';
53 if (send(sv[1], &sym, 1, 0) != 1)
54 perror_msg_and_fail("send");
55 if (close(sv[1]))
56 perror_msg_and_fail("close send");
57
58 int pidfd;
59 unsigned int cmsg_size = CMSG_SPACE(sizeof(pidfd));
60 struct cmsghdr *cmsg = tail_alloc(cmsg_size);
61 memset(cmsg, 0, cmsg_size);
62
63 struct iovec iov = {
64 .iov_base = &sym,
65 .iov_len = sizeof(sym)
66 };
67 struct msghdr mh = {
68 .msg_iov = &iov,
69 .msg_iovlen = 1,
70 .msg_control = cmsg,
71 .msg_controllen = cmsg_size
72 };
73
74 if (recvmsg(sv[0], &mh, 0) != 1)
75 perror_msg_and_fail("recvmsg");
76
77 printf("recvmsg(%d<socket:[%lu]>, {msg_name=NULL, msg_namelen=0"
78 ", msg_iov=[{iov_base=\"A\", iov_len=1}], msg_iovlen=1",
79 sv[0], inode_of_sockfd(sv[0]));
80
81 bool found = false;
82 if (mh.msg_controllen) {
83 printf(", msg_control=[");
84 for (struct cmsghdr *c = CMSG_FIRSTHDR(&mh); c;
85 c = CMSG_NXTHDR(&mh, c)) {
86 printf("%s{cmsg_len=%lu, cmsg_level=",
87 (c == cmsg ? "" : ", "),
88 (unsigned long) c->cmsg_len);
89 if (c->cmsg_level == SOL_SOCKET) {
90 printf("SOL_SOCKET");
91 } else {
92 printf("%d /* expected SOL_SOCKET == %d */",
93 c->cmsg_level, (int) SOL_SOCKET);
94 }
95 printf(", cmsg_type=");
96 if (c->cmsg_type == SCM_PIDFD) {
97 printf("SCM_PIDFD, cmsg_data=");
98 print_pidfd(c);
99 found = true;
100 } else {
101 printf("%d /* expected SCM_PIDFD == %d */",
102 c->cmsg_type, (int) SCM_PIDFD);
103 }
104 printf("}");
105 }
106 printf("]");
107 }
108 printf(", msg_controllen=%lu, msg_flags=0}, 0) = 1\n",
109 (unsigned long) mh.msg_controllen);
110
111 if (!found)
112 error_msg_and_fail("SCM_PIDFD not found");
113
114 puts("+++ exited with 0 +++");
115 return 0;
116 }