diff --git a/kernel/audit.c b/kernel/audit.c index a871bf8..182d31b 100644 --- a/kernel/audit.c +++ b/kernel/audit.c @@ -110,7 +110,6 @@ struct audit_net { * @pid: auditd PID * @portid: netlink portid * @net: the associated network namespace - * @lock: spinlock to protect write access * * Description: * This struct is RCU protected; you must either hold the RCU lock for reading @@ -120,8 +119,9 @@ static struct auditd_connection { int pid; u32 portid; struct net *net; - spinlock_t lock; -} auditd_conn; + struct rcu_head rcu; +} *auditd_conn; +static DEFINE_SPINLOCK(auditd_conn_lock); /* If audit_rate_limit is non-zero, limit the rate of sending audit records * to that number per second. This prevents DoS attacks, but results in @@ -223,9 +223,11 @@ struct audit_reply { int auditd_test_task(const struct task_struct *task) { int rc; + pid_t pid; rcu_read_lock(); - rc = (auditd_conn.pid && task->tgid == auditd_conn.pid ? 1 : 0); + pid = rcu_dereference(auditd_conn)->pid; + rc = (pid && task->tgid == pid ? 1 : 0); rcu_read_unlock(); return rc; @@ -426,30 +428,37 @@ static int audit_set_failure(u32 state) return audit_do_config_change("audit_failure", &audit_failure, state); } +static void auditd_conn_free(struct rcu_head *rcu) +{ + struct auditd_connection *cn = container_of(rcu, struct auditd_connection, rcu); + + if (cn->net) + put_net(cn->net); + kfree(cn); +} + /** - * auditd_set - Set/Reset the auditd connection state - * @pid: auditd PID - * @portid: auditd netlink portid - * @net: auditd network namespace pointer + * auditd_set - Set the auditd connection state + * @new_coon: the new auditd connection * * Description: * This function will obtain and drop network namespace references as * necessary. */ -static void auditd_set(int pid, u32 portid, struct net *net) +static void auditd_set(struct auditd_connection *new_conn) { + struct auditd_connection *old_conn; unsigned long flags; - spin_lock_irqsave(&auditd_conn.lock, flags); - auditd_conn.pid = pid; - auditd_conn.portid = portid; - if (auditd_conn.net) - put_net(auditd_conn.net); - if (net) - auditd_conn.net = get_net(net); - else - auditd_conn.net = NULL; - spin_unlock_irqrestore(&auditd_conn.lock, flags); + if (new_conn->net) + get_net(new_conn->net); + spin_lock_irqsave(&auditd_conn_lock, flags); + old_conn = rcu_dereference_protected(auditd_conn, + lockdep_is_held(&auditd_conn_lock)); + rcu_assign_pointer(auditd_conn, new_conn); + spin_unlock_irqrestore(&auditd_conn_lock, flags); + + call_rcu(&old_conn->rcu, auditd_conn_free); } /** @@ -537,19 +546,24 @@ static void kauditd_retry_skb(struct sk_buff *skb) /** * auditd_reset - Disconnect the auditd connection + * @flags: GFP flags * * Description: * Break the auditd/kauditd connection and move all the queued records into the * hold queue in case auditd reconnects. */ -static void auditd_reset(void) +static void auditd_reset(gfp_t flags) { struct sk_buff *skb; + struct auditd_connection *null_conn; + null_conn = kzalloc(sizeof(*null_conn), flags); + if (!null_conn) + return; /* if it isn't already broken, break the connection */ rcu_read_lock(); - if (auditd_conn.pid) - auditd_set(0, 0, NULL); + if (rcu_dereference(auditd_conn)->pid) + auditd_set(null_conn); rcu_read_unlock(); /* flush all of the main and retry queues to the hold queue */ @@ -585,15 +599,15 @@ static int auditd_send_unicast_skb(struct sk_buff *skb) * section netlink_unicast() should safely return an error */ rcu_read_lock(); - if (!auditd_conn.pid) { + if (!rcu_dereference(auditd_conn)->pid) { rcu_read_unlock(); rc = -ECONNREFUSED; goto err; } - net = auditd_conn.net; + net = rcu_dereference(auditd_conn)->net; get_net(net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = rcu_dereference(auditd_conn)->portid; rcu_read_unlock(); rc = netlink_unicast(sk, skb, portid, 0); @@ -605,7 +619,7 @@ static int auditd_send_unicast_skb(struct sk_buff *skb) err: if (rc == -ECONNREFUSED) - auditd_reset(); + auditd_reset(GFP_KERNEL); return rc; } @@ -735,14 +749,14 @@ static int kauditd_thread(void *dummy) while (!kthread_should_stop()) { /* NOTE: see the lock comments in auditd_send_unicast_skb() */ rcu_read_lock(); - if (!auditd_conn.pid) { + if (!rcu_dereference(auditd_conn)->pid) { rcu_read_unlock(); goto main_queue; } - net = auditd_conn.net; + net = rcu_dereference(auditd_conn)->net; get_net(net); sk = audit_get_sk(net); - portid = auditd_conn.portid; + portid = rcu_dereference(auditd_conn)->portid; rcu_read_unlock(); /* attempt to flush the hold queue */ @@ -751,7 +765,7 @@ static int kauditd_thread(void *dummy) NULL, kauditd_rehold_skb); if (rc < 0) { sk = NULL; - auditd_reset(); + auditd_reset(GFP_KERNEL); goto main_queue; } @@ -761,7 +775,7 @@ static int kauditd_thread(void *dummy) NULL, kauditd_hold_skb); if (rc < 0) { sk = NULL; - auditd_reset(); + auditd_reset(GFP_KERNEL); goto main_queue; } @@ -774,7 +788,7 @@ static int kauditd_thread(void *dummy) kauditd_send_multicast_skb, kauditd_retry_skb); if (sk == NULL || rc < 0) - auditd_reset(); + auditd_reset(GFP_KERNEL); sk = NULL; /* drop our netns reference, no auditd sends past this line */ @@ -1103,7 +1117,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) s.enabled = audit_enabled; s.failure = audit_failure; rcu_read_lock(); - s.pid = auditd_conn.pid; + s.pid = rcu_dereference(auditd_conn)->pid; rcu_read_unlock(); s.rate_limit = audit_rate_limit; s.backlog_limit = audit_backlog_limit; @@ -1139,17 +1153,22 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) int new_pid = s.pid; pid_t auditd_pid; pid_t requesting_pid = task_tgid_vnr(current); + struct auditd_connection *new; + new = kmalloc(sizeof(*new), GFP_KERNEL); + if (!new) + return -ENOMEM; /* test the auditd connection */ audit_replace(requesting_pid); rcu_read_lock(); - auditd_pid = auditd_conn.pid; + auditd_pid = rcu_dereference(auditd_conn)->pid; /* only the current auditd can unregister itself */ if ((!new_pid) && (requesting_pid != auditd_pid)) { rcu_read_unlock(); audit_log_config_change("audit_pid", new_pid, auditd_pid, 0); + kfree(new); return -EACCES; } /* replacing a healthy auditd is not allowed */ @@ -1157,6 +1176,7 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) rcu_read_unlock(); audit_log_config_change("audit_pid", new_pid, auditd_pid, 0); + kfree(new); return -EEXIST; } rcu_read_unlock(); @@ -1166,15 +1186,16 @@ static int audit_receive_msg(struct sk_buff *skb, struct nlmsghdr *nlh) auditd_pid, 1); if (new_pid) { + new->pid = new_pid; + new->portid = NETLINK_CB(skb).portid; + new->net = sock_net(NETLINK_CB(skb).sk); /* register a new auditd connection */ - auditd_set(new_pid, - NETLINK_CB(skb).portid, - sock_net(NETLINK_CB(skb).sk)); + auditd_set(new); /* try to process any backlog */ wake_up_interruptible(&kauditd_wait); } else /* unregister the auditd connection */ - auditd_reset(); + auditd_reset(GFP_KERNEL); } if (s.mask & AUDIT_STATUS_RATE_LIMIT) { err = audit_set_rate_limit(s.rate_limit); @@ -1448,8 +1469,8 @@ static void __net_exit audit_net_exit(struct net *net) struct audit_net *aunet = net_generic(net, audit_net_id); rcu_read_lock(); - if (net == auditd_conn.net) - auditd_reset(); + if (net == rcu_dereference(auditd_conn)->net) + auditd_reset(GFP_ATOMIC); rcu_read_unlock(); netlink_kernel_release(aunet->sk); @@ -1470,8 +1491,9 @@ static int __init audit_init(void) if (audit_initialized == AUDIT_DISABLED) return 0; - memset(&auditd_conn, 0, sizeof(auditd_conn)); - spin_lock_init(&auditd_conn.lock); + auditd_conn = kzalloc(sizeof(*auditd_conn), GFP_KERNEL); + if (!auditd_conn) + return -ENOMEM; skb_queue_head_init(&audit_queue); skb_queue_head_init(&audit_retry_queue);