From mboxrd@z Thu Jan 1 00:00:00 1970
Return-Path: 
Received: from mail-ve1eur01on0062.outbound.protection.outlook.com
	([104.47.1.62]:44407 "EHLO EUR01-VE1-obe.outbound.protection.outlook.com"
	rhost-flags-OK-OK-OK-FAIL) by vger.kernel.org with ESMTP id S1751480AbeCUPxl
	(ORCPT ); Wed, 21 Mar 2018 11:53:41 -0400
Subject: Re: [PATCH net-next 06/14] net/tls: Add generic NIC offload
 infrastructure
To: Kirill Tkhai , Saeed Mahameed , "David S. Miller"
Cc: netdev@vger.kernel.org, Dave Watson , Ilya Lesokhin , Aviad Yehezkel
References: <20180320024510.7408-1-saeedm@mellanox.com>
 <20180320024510.7408-7-saeedm@mellanox.com>
From: Boris Pismenny
Message-ID: <0d69bcee-0334-0da7-4377-0422f62ae1a7@mellanox.com>
Date: Wed, 21 Mar 2018 17:53:32 +0200
MIME-Version: 1.0
In-Reply-To: 
Content-Type: text/plain; charset=utf-8; format=flowed
Content-Language: en-US
Content-Transfer-Encoding: 7bit
Sender: netdev-owner@vger.kernel.org
List-ID: 

...

>
> Other patches have two licenses in header. Can I distribute this file under GPL license terms?
>

Sure, I'll update the license to match other files under net/tls.

>> +#include 
>> +#include 
>> +#include 
>> +#include 
>> +#include 
>> +
>> +#include 
>> +#include 
>> +
>> +/* device_offload_lock is used to synchronize tls_dev_add
>> + * against NETDEV_DOWN notifications.
>> + */
>> +DEFINE_STATIC_PERCPU_RWSEM(device_offload_lock);
>> +
>> +static void tls_device_gc_task(struct work_struct *work);
>> +
>> +static DECLARE_WORK(tls_device_gc_work, tls_device_gc_task);
>> +static LIST_HEAD(tls_device_gc_list);
>> +static LIST_HEAD(tls_device_list);
>> +static DEFINE_SPINLOCK(tls_device_lock);
>> +
>> +static void tls_device_free_ctx(struct tls_context *ctx)
>> +{
>> +	struct tls_offload_context *offlad_ctx = tls_offload_ctx(ctx);
>> +
>> +	kfree(offlad_ctx);
>> +	kfree(ctx);
>> +}
>> +
>> +static void tls_device_gc_task(struct work_struct *work)
>> +{
>> +	struct tls_context *ctx, *tmp;
>> +	struct list_head gc_list;
>> +	unsigned long flags;
>> +
>> +	spin_lock_irqsave(&tls_device_lock, flags);
>> +	INIT_LIST_HEAD(&gc_list);
>
> This is a stack variable, and it should be initialized outside of the global spinlock.
> There is the LIST_HEAD() primitive for that in the kernel.
> There is one more similar place below.
>

Sure.

>> +	list_splice_init(&tls_device_gc_list, &gc_list);
>> +	spin_unlock_irqrestore(&tls_device_lock, flags);
>> +
>> +	list_for_each_entry_safe(ctx, tmp, &gc_list, list) {
>> +		struct net_device *netdev = ctx->netdev;
>> +
>> +		if (netdev) {
>> +			netdev->tlsdev_ops->tls_dev_del(netdev, ctx,
>> +							TLS_OFFLOAD_CTX_DIR_TX);
>> +			dev_put(netdev);
>> +		}
>
> How is it possible to meet a NULL netdev here?
>

This can happen in tls_device_down. tls_device_down is called whenever a netdev that is used for TLS inline crypto offload goes down. It gets called via the NETDEV_DOWN event of the netdevice notifier. This flow is somewhat similar to the xfrm_device netdev notifier. However, we do not destroy the socket (as in destroying the xfrm_state in xfrm_device). Instead, we clean up the netdev state and allow software fallback to handle the rest of the traffic.

>> +
>> +		list_del(&ctx->list);
>> +		tls_device_free_ctx(ctx);
>> +	}
>> +}
>> +
>> +static void tls_device_queue_ctx_destruction(struct tls_context *ctx)
>> +{
>> +	unsigned long flags;
>> +
>> +	spin_lock_irqsave(&tls_device_lock, flags);
>> +	list_move_tail(&ctx->list, &tls_device_gc_list);
>> +
>> +	/* schedule_work inside the spinlock
>> +	 * to make sure tls_device_down waits for that work.
>> + */ >> + schedule_work(&tls_device_gc_work); >> + >> + spin_unlock_irqrestore(&tls_device_lock, flags); >> +} >> + >> +/* We assume that the socket is already connected */ >> +static struct net_device *get_netdev_for_sock(struct sock *sk) >> +{ >> + struct inet_sock *inet = inet_sk(sk); >> + struct net_device *netdev = NULL; >> + >> + netdev = dev_get_by_index(sock_net(sk), inet->cork.fl.flowi_oif); >> + >> + return netdev; >> +} >> + >> +static int attach_sock_to_netdev(struct sock *sk, struct net_device *netdev, >> + struct tls_context *ctx) >> +{ >> + int rc; >> + >> + rc = netdev->tlsdev_ops->tls_dev_add(netdev, sk, TLS_OFFLOAD_CTX_DIR_TX, >> + &ctx->crypto_send, >> + tcp_sk(sk)->write_seq); >> + if (rc) { >> + pr_err_ratelimited("The netdev has refused to offload this socket\n"); >> + goto out; >> + } >> + >> + rc = 0; >> +out: >> + return rc; >> +} >> + >> +static void destroy_record(struct tls_record_info *record) >> +{ >> + skb_frag_t *frag; >> + int nr_frags = record->num_frags; >> + >> + while (nr_frags > 0) { >> + frag = &record->frags[nr_frags - 1]; >> + __skb_frag_unref(frag); >> + --nr_frags; >> + } >> + kfree(record); >> +} >> + >> +static void delete_all_records(struct tls_offload_context *offload_ctx) >> +{ >> + struct tls_record_info *info, *temp; >> + >> + list_for_each_entry_safe(info, temp, &offload_ctx->records_list, list) { >> + list_del(&info->list); >> + destroy_record(info); >> + } >> + >> + offload_ctx->retransmit_hint = NULL; >> +} >> + >> +static void tls_icsk_clean_acked(struct sock *sk, u32 acked_seq) >> +{ >> + struct tls_context *tls_ctx = tls_get_ctx(sk); >> + struct tls_offload_context *ctx; >> + struct tls_record_info *info, *temp; >> + unsigned long flags; >> + u64 deleted_records = 0; >> + >> + if (!tls_ctx) >> + return; >> + >> + ctx = tls_offload_ctx(tls_ctx); >> + >> + spin_lock_irqsave(&ctx->lock, flags); >> + info = ctx->retransmit_hint; >> + if (info && !before(acked_seq, info->end_seq)) { >> + ctx->retransmit_hint = NULL; >> + list_del(&info->list); >> + destroy_record(info); >> + deleted_records++; >> + } >> + >> + list_for_each_entry_safe(info, temp, &ctx->records_list, list) { >> + if (before(acked_seq, info->end_seq)) >> + break; >> + list_del(&info->list); >> + >> + destroy_record(info); >> + deleted_records++; >> + } >> + >> + ctx->unacked_record_sn += deleted_records; >> + spin_unlock_irqrestore(&ctx->lock, flags); >> +} >> + >> +/* At this point, there should be no references on this >> + * socket and no in-flight SKBs associated with this >> + * socket, so it is safe to free all the resources. 
>> + */ >> +void tls_device_sk_destruct(struct sock *sk) >> +{ >> + struct tls_context *tls_ctx = tls_get_ctx(sk); >> + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); >> + >> + if (ctx->open_record) >> + destroy_record(ctx->open_record); >> + >> + delete_all_records(ctx); >> + crypto_free_aead(ctx->aead_send); >> + ctx->sk_destruct(sk); >> + >> + if (refcount_dec_and_test(&tls_ctx->refcount)) >> + tls_device_queue_ctx_destruction(tls_ctx); >> +} >> +EXPORT_SYMBOL(tls_device_sk_destruct); >> + >> +static inline void tls_append_frag(struct tls_record_info *record, >> + struct page_frag *pfrag, >> + int size) >> +{ >> + skb_frag_t *frag; >> + >> + frag = &record->frags[record->num_frags - 1]; >> + if (frag->page.p == pfrag->page && >> + frag->page_offset + frag->size == pfrag->offset) { >> + frag->size += size; >> + } else { >> + ++frag; >> + frag->page.p = pfrag->page; >> + frag->page_offset = pfrag->offset; >> + frag->size = size; >> + ++record->num_frags; >> + get_page(pfrag->page); >> + } >> + >> + pfrag->offset += size; >> + record->len += size; >> +} >> + >> +static inline int tls_push_record(struct sock *sk, >> + struct tls_context *ctx, >> + struct tls_offload_context *offload_ctx, >> + struct tls_record_info *record, >> + struct page_frag *pfrag, >> + int flags, >> + unsigned char record_type) >> +{ >> + skb_frag_t *frag; >> + struct tcp_sock *tp = tcp_sk(sk); >> + struct page_frag fallback_frag; >> + struct page_frag *tag_pfrag = pfrag; >> + int i; >> + >> + /* fill prepand */ >> + frag = &record->frags[0]; >> + tls_fill_prepend(ctx, >> + skb_frag_address(frag), >> + record->len - ctx->prepend_size, >> + record_type); >> + >> + if (unlikely(!skb_page_frag_refill(ctx->tag_size, pfrag, GFP_KERNEL))) { >> + /* HW doesn't care about the data in the tag >> + * so in case pfrag has no room >> + * for a tag and we can't allocate a new pfrag >> + * just use the page in the first frag >> + * rather then write a complicated fall back code. 
>> + */ >> + tag_pfrag = &fallback_frag; >> + tag_pfrag->page = skb_frag_page(frag); >> + tag_pfrag->offset = 0; >> + } >> + >> + tls_append_frag(record, tag_pfrag, ctx->tag_size); >> + record->end_seq = tp->write_seq + record->len; >> + spin_lock_irq(&offload_ctx->lock); >> + list_add_tail(&record->list, &offload_ctx->records_list); >> + spin_unlock_irq(&offload_ctx->lock); >> + offload_ctx->open_record = NULL; >> + set_bit(TLS_PENDING_CLOSED_RECORD, &ctx->flags); >> + tls_advance_record_sn(sk, ctx); >> + >> + for (i = 0; i < record->num_frags; i++) { >> + frag = &record->frags[i]; >> + sg_unmark_end(&offload_ctx->sg_tx_data[i]); >> + sg_set_page(&offload_ctx->sg_tx_data[i], skb_frag_page(frag), >> + frag->size, frag->page_offset); >> + sk_mem_charge(sk, frag->size); >> + get_page(skb_frag_page(frag)); >> + } >> + sg_mark_end(&offload_ctx->sg_tx_data[record->num_frags - 1]); >> + >> + /* all ready, send */ >> + return tls_push_sg(sk, ctx, offload_ctx->sg_tx_data, 0, flags); >> +} >> + >> +static inline int tls_create_new_record(struct tls_offload_context *offload_ctx, >> + struct page_frag *pfrag, >> + size_t prepend_size) >> +{ >> + skb_frag_t *frag; >> + struct tls_record_info *record; >> + >> + record = kmalloc(sizeof(*record), GFP_KERNEL); >> + if (!record) >> + return -ENOMEM; >> + >> + frag = &record->frags[0]; >> + __skb_frag_set_page(frag, pfrag->page); >> + frag->page_offset = pfrag->offset; >> + skb_frag_size_set(frag, prepend_size); >> + >> + get_page(pfrag->page); >> + pfrag->offset += prepend_size; >> + >> + record->num_frags = 1; >> + record->len = prepend_size; >> + offload_ctx->open_record = record; >> + return 0; >> +} >> + >> +static inline int tls_do_allocation(struct sock *sk, >> + struct tls_offload_context *offload_ctx, >> + struct page_frag *pfrag, >> + size_t prepend_size) >> +{ >> + int ret; >> + >> + if (!offload_ctx->open_record) { >> + if (unlikely(!skb_page_frag_refill(prepend_size, pfrag, >> + sk->sk_allocation))) { >> + sk->sk_prot->enter_memory_pressure(sk); >> + sk_stream_moderate_sndbuf(sk); >> + return -ENOMEM; >> + } >> + >> + ret = tls_create_new_record(offload_ctx, pfrag, prepend_size); >> + if (ret) >> + return ret; >> + >> + if (pfrag->size > pfrag->offset) >> + return 0; >> + } >> + >> + if (!sk_page_frag_refill(sk, pfrag)) >> + return -ENOMEM; >> + >> + return 0; >> +} >> + >> +static int tls_push_data(struct sock *sk, >> + struct iov_iter *msg_iter, >> + size_t size, int flags, >> + unsigned char record_type) >> +{ >> + struct tls_context *tls_ctx = tls_get_ctx(sk); >> + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); >> + struct tls_record_info *record = ctx->open_record; >> + struct page_frag *pfrag; >> + int copy, rc = 0; >> + size_t orig_size = size; >> + u32 max_open_record_len; >> + long timeo; >> + int more = flags & (MSG_SENDPAGE_NOTLAST | MSG_MORE); >> + int tls_push_record_flags = flags | MSG_SENDPAGE_NOTLAST; >> + bool done = false; >> + >> + if (flags & >> + ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SENDPAGE_NOTLAST)) >> + return -ENOTSUPP; >> + >> + if (sk->sk_err) >> + return -sk->sk_err; >> + >> + timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); >> + rc = tls_complete_pending_work(sk, tls_ctx, flags, &timeo); >> + if (rc < 0) >> + return rc; >> + >> + pfrag = sk_page_frag(sk); >> + >> + /* KTLS_TLS_HEADER_SIZE is not counted as part of the TLS record, and >> + * we need to leave room for an authentication tag. 
>> + */ >> + max_open_record_len = TLS_MAX_PAYLOAD_SIZE + >> + tls_ctx->prepend_size; >> + do { >> + if (tls_do_allocation(sk, ctx, pfrag, >> + tls_ctx->prepend_size)) { >> + rc = sk_stream_wait_memory(sk, &timeo); >> + if (!rc) >> + continue; >> + >> + record = ctx->open_record; >> + if (!record) >> + break; >> +handle_error: >> + if (record_type != TLS_RECORD_TYPE_DATA) { >> + /* avoid sending partial >> + * record with type != >> + * application_data >> + */ >> + size = orig_size; >> + destroy_record(record); >> + ctx->open_record = NULL; >> + } else if (record->len > tls_ctx->prepend_size) { >> + goto last_record; >> + } >> + >> + break; >> + } >> + >> + record = ctx->open_record; >> + copy = min_t(size_t, size, (pfrag->size - pfrag->offset)); >> + copy = min_t(size_t, copy, (max_open_record_len - record->len)); >> + >> + if (copy_from_iter_nocache(page_address(pfrag->page) + >> + pfrag->offset, >> + copy, msg_iter) != copy) { >> + rc = -EFAULT; >> + goto handle_error; >> + } >> + tls_append_frag(record, pfrag, copy); >> + >> + size -= copy; >> + if (!size) { >> +last_record: >> + tls_push_record_flags = flags; >> + if (more) { >> + tls_ctx->pending_open_record_frags = >> + record->num_frags; >> + break; >> + } >> + >> + done = true; >> + } >> + >> + if ((done) || record->len >= max_open_record_len || >> + (record->num_frags >= MAX_SKB_FRAGS - 1)) { >> + rc = tls_push_record(sk, >> + tls_ctx, >> + ctx, >> + record, >> + pfrag, >> + tls_push_record_flags, >> + record_type); >> + if (rc < 0) >> + break; >> + } >> + } while (!done); >> + >> + if (orig_size - size > 0) >> + rc = orig_size - size; >> + >> + return rc; >> +} >> + >> +int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) >> +{ >> + unsigned char record_type = TLS_RECORD_TYPE_DATA; >> + int rc = 0; >> + >> + lock_sock(sk); >> + >> + if (unlikely(msg->msg_controllen)) { >> + rc = tls_proccess_cmsg(sk, msg, &record_type); >> + if (rc) >> + goto out; >> + } >> + >> + rc = tls_push_data(sk, &msg->msg_iter, size, >> + msg->msg_flags, record_type); >> + >> +out: >> + release_sock(sk); >> + return rc; >> +} >> + >> +int tls_device_sendpage(struct sock *sk, struct page *page, >> + int offset, size_t size, int flags) >> +{ >> + struct iov_iter msg_iter; >> + struct kvec iov; >> + char *kaddr = kmap(page); >> + int rc = 0; >> + >> + if (flags & MSG_SENDPAGE_NOTLAST) >> + flags |= MSG_MORE; >> + >> + lock_sock(sk); >> + >> + if (flags & MSG_OOB) { >> + rc = -ENOTSUPP; >> + goto out; >> + } >> + >> + iov.iov_base = kaddr + offset; >> + iov.iov_len = size; >> + iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, &iov, 1, size); >> + rc = tls_push_data(sk, &msg_iter, size, >> + flags, TLS_RECORD_TYPE_DATA); >> + kunmap(page); >> + >> +out: >> + release_sock(sk); >> + return rc; >> +} >> + >> +struct tls_record_info *tls_get_record(struct tls_offload_context *context, >> + u32 seq, u64 *p_record_sn) >> +{ >> + struct tls_record_info *info; >> + u64 record_sn = context->hint_record_sn; >> + >> + info = context->retransmit_hint; >> + if (!info || >> + before(seq, info->end_seq - info->len)) { >> + /* if retransmit_hint is irrelevant start >> + * from the begging of the list >> + */ >> + info = list_first_entry(&context->records_list, >> + struct tls_record_info, list); >> + record_sn = context->unacked_record_sn; >> + } >> + >> + list_for_each_entry_from(info, &context->records_list, list) { >> + if (before(seq, info->end_seq)) { >> + if (!context->retransmit_hint || >> + after(info->end_seq, >> + context->retransmit_hint->end_seq)) 
{ >> + context->hint_record_sn = record_sn; >> + context->retransmit_hint = info; >> + } >> + *p_record_sn = record_sn; >> + return info; >> + } >> + record_sn++; >> + } >> + >> + return NULL; >> +} >> +EXPORT_SYMBOL(tls_get_record); >> + >> +static int tls_device_push_pending_record(struct sock *sk, int flags) >> +{ >> + struct iov_iter msg_iter; >> + >> + iov_iter_kvec(&msg_iter, WRITE | ITER_KVEC, NULL, 0, 0); >> + return tls_push_data(sk, &msg_iter, 0, flags, TLS_RECORD_TYPE_DATA); >> +} >> + >> +int tls_set_device_offload(struct sock *sk, struct tls_context *ctx) >> +{ >> + u16 nonece_size, tag_size, iv_size, rec_seq_size; >> + struct tls_record_info *start_marker_record; >> + struct tls_offload_context *offload_ctx; >> + struct tls_crypto_info *crypto_info; >> + struct net_device *netdev; >> + char *iv, *rec_seq; >> + struct sk_buff *skb; >> + int rc = -EINVAL; >> + __be64 rcd_sn; >> + >> + if (!ctx) >> + goto out; >> + >> + if (ctx->priv_ctx) { >> + rc = -EEXIST; >> + goto out; >> + } >> + >> + /* We support starting offload on multiple sockets >> + * concurrently, So we only need a read lock here. >> + */ >> + percpu_down_read(&device_offload_lock); >> + netdev = get_netdev_for_sock(sk); >> + if (!netdev) { >> + pr_err_ratelimited("%s: netdev not found\n", __func__); >> + rc = -EINVAL; >> + goto release_lock; >> + } >> + >> + if (!(netdev->features & NETIF_F_HW_TLS_TX)) { >> + rc = -ENOTSUPP; >> + goto release_netdev; >> + } >> + >> + /* Avoid offloading if the device is down >> + * We don't want to offload new flows after >> + * the NETDEV_DOWN event >> + */ >> + if (!(netdev->flags & IFF_UP)) { >> + rc = -EINVAL; >> + goto release_lock; >> + } >> + >> + crypto_info = &ctx->crypto_send; >> + switch (crypto_info->cipher_type) { >> + case TLS_CIPHER_AES_GCM_128: { >> + nonece_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; >> + tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE; >> + iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE; >> + iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv; >> + rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE; >> + rec_seq = >> + ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq; >> + break; >> + } >> + default: >> + rc = -EINVAL; >> + goto release_netdev; >> + } >> + >> + start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL); > > Can we memory allocations and simple memory initializations ouside the global rwsem? > Sure, we can move all memory allocations outside the lock. 
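
For v2 I'm thinking of something along these lines (rough, untested sketch; only the ordering matters here, and the error labels will need to be reshuffled accordingly) - allocate and initialize everything first, and only take the read lock around the netdev lookup and tls_dev_add():

	start_marker_record = kmalloc(sizeof(*start_marker_record), GFP_KERNEL);
	if (!start_marker_record)
		return -ENOMEM;

	offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL);
	if (!offload_ctx) {
		rc = -ENOMEM;
		goto free_marker_record;
	}

	/* crypto_send parsing, ctx->iv/ctx->rec_seq allocation and
	 * tls_sw_fallback_init() also move here, before the lock
	 */

	percpu_down_read(&device_offload_lock);
	netdev = get_netdev_for_sock(sk);
	if (!netdev) {
		rc = -EINVAL;
		goto release_lock;
	}

	/* only the feature/IFF_UP checks, tls_dev_add() and publishing
	 * the context remain under device_offload_lock
	 */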
>> + if (!start_marker_record) { >> + rc = -ENOMEM; >> + goto release_netdev; >> + } >> + >> + offload_ctx = kzalloc(TLS_OFFLOAD_CONTEXT_SIZE, GFP_KERNEL); >> + if (!offload_ctx) >> + goto free_marker_record; >> + >> + ctx->priv_ctx = offload_ctx; >> + rc = attach_sock_to_netdev(sk, netdev, ctx); >> + if (rc) >> + goto free_offload_context; >> + >> + ctx->netdev = netdev; >> + ctx->prepend_size = TLS_HEADER_SIZE + nonece_size; >> + ctx->tag_size = tag_size; >> + ctx->iv_size = iv_size; >> + ctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE, >> + GFP_KERNEL); >> + if (!ctx->iv) { >> + rc = -ENOMEM; >> + goto detach_sock; >> + } >> + >> + memcpy(ctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size); >> + >> + ctx->rec_seq_size = rec_seq_size; >> + ctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL); >> + if (!ctx->rec_seq) { >> + rc = -ENOMEM; >> + goto free_iv; >> + } >> + memcpy(ctx->rec_seq, rec_seq, rec_seq_size); >> + >> + /* start at rec_seq - 1 to account for the start marker record */ >> + memcpy(&rcd_sn, ctx->rec_seq, sizeof(rcd_sn)); >> + offload_ctx->unacked_record_sn = be64_to_cpu(rcd_sn) - 1; >> + >> + rc = tls_sw_fallback_init(sk, offload_ctx, crypto_info); >> + if (rc) >> + goto free_rec_seq; >> + >> + start_marker_record->end_seq = tcp_sk(sk)->write_seq; >> + start_marker_record->len = 0; >> + start_marker_record->num_frags = 0; >> + >> + INIT_LIST_HEAD(&offload_ctx->records_list); >> + list_add_tail(&start_marker_record->list, &offload_ctx->records_list); >> + spin_lock_init(&offload_ctx->lock); >> + >> + inet_csk(sk)->icsk_clean_acked = &tls_icsk_clean_acked; >> + ctx->push_pending_record = tls_device_push_pending_record; >> + offload_ctx->sk_destruct = sk->sk_destruct; >> + >> + /* TLS offload is greatly simplified if we don't send >> + * SKBs where only part of the payload needs to be encrypted. >> + * So mark the last skb in the write queue as end of record. >> + */ >> + skb = tcp_write_queue_tail(sk); >> + if (skb) >> + TCP_SKB_CB(skb)->eor = 1; >> + >> + refcount_set(&ctx->refcount, 1); >> + spin_lock_irq(&tls_device_lock); >> + list_add_tail(&ctx->list, &tls_device_list); >> + spin_unlock_irq(&tls_device_lock); >> + >> + /* following this assignment tls_is_sk_tx_device_offloaded >> + * will return true and the context might be accessed >> + * by the netdev's xmit function. >> + */ >> + smp_store_release(&sk->sk_destruct, >> + &tls_device_sk_destruct); >> + goto release_lock; >> + >> +free_rec_seq: >> + kfree(ctx->rec_seq); >> +free_iv: >> + kfree(ctx->iv); >> +detach_sock: >> + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, TLS_OFFLOAD_CTX_DIR_TX); >> +free_offload_context: >> + kfree(offload_ctx); >> + ctx->priv_ctx = NULL; >> +free_marker_record: >> + kfree(start_marker_record); >> +release_netdev: >> + dev_put(netdev); >> +release_lock: >> + percpu_up_read(&device_offload_lock); >> +out: >> + return rc; >> +} >> + >> +static int tls_device_register(struct net_device *dev) >> +{ >> + if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops) >> + return NOTIFY_BAD; >> + >> + return NOTIFY_DONE; >> +} > > This function is the same as tls_device_feat_change(). Can't we merge > them together and avoid duplicating of code? > Sure. >> +static int tls_device_unregister(struct net_device *dev) >> +{ >> + return NOTIFY_DONE; >> +} > > This function does nothing, and next patches do not change it. > Can't we remove it since so? > Sure. 
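
To make sure we're aligned, the register/feat_change merge would look roughly like this (sketch only; the helper name is just a placeholder):

static int tls_dev_feature_check(struct net_device *dev)
{
	/* a device advertising TLS TX offload must provide tlsdev_ops */
	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
		return NOTIFY_BAD;

	return NOTIFY_DONE;
}

static int tls_dev_event(struct notifier_block *this, unsigned long event,
			 void *ptr)
{
	struct net_device *dev = netdev_notifier_info_to_dev(ptr);

	switch (event) {
	case NETDEV_REGISTER:
	case NETDEV_FEAT_CHANGE:
		return tls_dev_feature_check(dev);

	case NETDEV_DOWN:
		return tls_device_down(dev);
	}
	return NOTIFY_DONE;
}

NETDEV_UNREGISTER then simply falls through to NOTIFY_DONE, so tls_device_unregister() can go away as you suggest.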
>> +static int tls_device_feat_change(struct net_device *dev)
>> +{
>> +	if ((dev->features & NETIF_F_HW_TLS_TX) && !dev->tlsdev_ops)
>> +		return NOTIFY_BAD;
>> +
>> +	return NOTIFY_DONE;
>> +}
>> +
>> +static int tls_device_down(struct net_device *netdev)
>> +{
>> +	struct tls_context *ctx, *tmp;
>> +	struct list_head list;
>> +	unsigned long flags;
>> +
>> +	if (!(netdev->features & NETIF_F_HW_TLS_TX))
>> +		return NOTIFY_DONE;
>
> Can't we move this check into tls_dev_event() and use it for all types of events?
> Then we avoid duplicate code.
>

No. Not all events require this check. Also, the result is different for different events.

>> +
>> +	/* Request a write lock to block new offload attempts
>> +	 */
>> +	percpu_down_write(&device_offload_lock);
>
> What is the reason percpu_rwsem is chosen here? It looks like this primitive
> gives more advantages to readers than a plain rwsem does. But it also gives
> disadvantages to writers. That would be fine, unless tls_device_down() is called
> with rtnl_lock() held from the netdevice notifier. But since netdevice notifiers
> are called with rtnl_lock() held, the percpu_rwsem will increase the time
> rtnl_lock() is held.

We use a rwsem to allow multiple (reader) invocations of tls_set_device_offload, which is triggered by the user (presumably) during the TLS handshake. This might be considered a fast path. However, we must block all calls to tls_set_device_offload while we are processing NETDEV_DOWN events (writer). As you've mentioned, the percpu rwsem is more efficient for readers, especially on NUMA systems, where cache-line bouncing during reader acquire reduces performance.

>
> Can't we use a plain rwsem here instead?
>

It's a performance tradeoff. I'm not certain that the percpu rwsem write-side acquire is significantly worse than using a global rwsem. For now, while all of this is experimental, can we agree to focus on the performance of readers? We can change it later if it becomes a problem.

>> +
>> +	spin_lock_irqsave(&tls_device_lock, flags);
>> +	INIT_LIST_HEAD(&list);
>
> This may go outside the global spinlock.
>

Sure.
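
Both places can just use the on-stack LIST_HEAD() initializer, e.g. here (sketch of the change):

static int tls_device_down(struct net_device *netdev)
{
	struct tls_context *ctx, *tmp;
	unsigned long flags;
	LIST_HEAD(list);	/* replaces struct list_head list + INIT_LIST_HEAD() under the lock */

and similarly LIST_HEAD(gc_list) in tls_device_gc_task(), so nothing is initialized while tls_device_lock is held.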
>> + list_for_each_entry_safe(ctx, tmp, &tls_device_list, list) { >> + if (ctx->netdev != netdev || >> + !refcount_inc_not_zero(&ctx->refcount)) >> + continue; >> + >> + list_move(&ctx->list, &list); >> + } >> + spin_unlock_irqrestore(&tls_device_lock, flags); >> + >> + list_for_each_entry_safe(ctx, tmp, &list, list) { >> + netdev->tlsdev_ops->tls_dev_del(netdev, ctx, >> + TLS_OFFLOAD_CTX_DIR_TX); >> + ctx->netdev = NULL; >> + dev_put(netdev); >> + list_del_init(&ctx->list); >> + >> + if (refcount_dec_and_test(&ctx->refcount)) >> + tls_device_free_ctx(ctx); >> + } >> + >> + percpu_up_write(&device_offload_lock); >> + >> + flush_work(&tls_device_gc_work); >> + >> + return NOTIFY_DONE; >> +} >> + >> +static int tls_dev_event(struct notifier_block *this, unsigned long event, >> + void *ptr) >> +{ >> + struct net_device *dev = netdev_notifier_info_to_dev(ptr); >> + >> + switch (event) { >> + case NETDEV_REGISTER: >> + return tls_device_register(dev); >> + >> + case NETDEV_UNREGISTER: >> + return tls_device_unregister(dev); >> + >> + case NETDEV_FEAT_CHANGE: >> + return tls_device_feat_change(dev); >> + >> + case NETDEV_DOWN: >> + return tls_device_down(dev); >> + } >> + return NOTIFY_DONE; >> +} >> + >> +static struct notifier_block tls_dev_notifier = { >> + .notifier_call = tls_dev_event, >> +}; >> + >> +void __init tls_device_init(void) >> +{ >> + register_netdevice_notifier(&tls_dev_notifier); >> +} >> + >> +void __exit tls_device_cleanup(void) >> +{ >> + unregister_netdevice_notifier(&tls_dev_notifier); >> + flush_work(&tls_device_gc_work); >> +} >> diff --git a/net/tls/tls_device_fallback.c b/net/tls/tls_device_fallback.c >> new file mode 100644 >> index 000000000000..14d31a36885c >> --- /dev/null >> +++ b/net/tls/tls_device_fallback.c >> @@ -0,0 +1,419 @@ >> +/* Copyright (c) 2018, Mellanox Technologies All rights reserved. >> + * >> + * Redistribution and use in source and binary forms, with or >> + * without modification, are permitted provided that the following >> + * conditions are met: >> + * >> + * - Redistributions of source code must retain the above >> + * copyright notice, this list of conditions and the following >> + * disclaimer. >> + * >> + * - Redistributions in binary form must reproduce the above >> + * copyright notice, this list of conditions and the following >> + * disclaimer in the documentation and/or other materials >> + * provided with the distribution. >> + * >> + * - Neither the name of the Mellanox Technologies nor the >> + * names of its contributors may be used to endorse or promote >> + * products derived from this software without specific prior written >> + * permission. >> + * >> + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" >> + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, >> + * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR >> + * A PARTICULAR PURPOSE ARE DISCLAIMED. 
>> + * IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR >> + * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL >> + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR >> + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) >> + * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, >> + * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING >> + * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE >> + * POSSIBILITY OF SUCH DAMAGE >> + */ >> + >> +#include >> +#include >> +#include >> +#include >> + >> +static void chain_to_walk(struct scatterlist *sg, struct scatter_walk *walk) >> +{ >> + struct scatterlist *src = walk->sg; >> + int diff = walk->offset - src->offset; >> + >> + sg_set_page(sg, sg_page(src), >> + src->length - diff, walk->offset); >> + >> + scatterwalk_crypto_chain(sg, sg_next(src), 0, 2); >> +} >> + >> +static int tls_enc_record(struct aead_request *aead_req, >> + struct crypto_aead *aead, char *aad, char *iv, >> + __be64 rcd_sn, struct scatter_walk *in, >> + struct scatter_walk *out, int *in_len) >> +{ >> + struct scatterlist sg_in[3]; >> + struct scatterlist sg_out[3]; >> + unsigned char buf[TLS_HEADER_SIZE + TLS_CIPHER_AES_GCM_128_IV_SIZE]; >> + u16 len; >> + int rc; >> + >> + len = min_t(int, *in_len, ARRAY_SIZE(buf)); >> + >> + scatterwalk_copychunks(buf, in, len, 0); >> + scatterwalk_copychunks(buf, out, len, 1); >> + >> + *in_len -= len; >> + if (!*in_len) >> + return 0; >> + >> + scatterwalk_pagedone(in, 0, 1); >> + scatterwalk_pagedone(out, 1, 1); >> + >> + len = buf[4] | (buf[3] << 8); >> + len -= TLS_CIPHER_AES_GCM_128_IV_SIZE; >> + >> + tls_make_aad(aad, len - TLS_CIPHER_AES_GCM_128_TAG_SIZE, >> + (char *)&rcd_sn, sizeof(rcd_sn), buf[0]); >> + >> + memcpy(iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, buf + TLS_HEADER_SIZE, >> + TLS_CIPHER_AES_GCM_128_IV_SIZE); >> + >> + sg_init_table(sg_in, ARRAY_SIZE(sg_in)); >> + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); >> + sg_set_buf(sg_in, aad, TLS_AAD_SPACE_SIZE); >> + sg_set_buf(sg_out, aad, TLS_AAD_SPACE_SIZE); >> + chain_to_walk(sg_in + 1, in); >> + chain_to_walk(sg_out + 1, out); >> + >> + *in_len -= len; >> + if (*in_len < 0) { >> + *in_len += TLS_CIPHER_AES_GCM_128_TAG_SIZE; >> + if (*in_len < 0) >> + /* the input buffer doesn't contain the entire record. >> + * trim len accordingly. The resulting authentication tag >> + * will contain garbage. 
but we don't care as we won't >> + * include any of it in the output skb >> + * Note that we assume the output buffer length >> + * is larger then input buffer length + tag size >> + */ >> + len += *in_len; >> + >> + *in_len = 0; >> + } >> + >> + if (*in_len) { >> + scatterwalk_copychunks(NULL, in, len, 2); >> + scatterwalk_pagedone(in, 0, 1); >> + scatterwalk_copychunks(NULL, out, len, 2); >> + scatterwalk_pagedone(out, 1, 1); >> + } >> + >> + len -= TLS_CIPHER_AES_GCM_128_TAG_SIZE; >> + aead_request_set_crypt(aead_req, sg_in, sg_out, len, iv); >> + >> + rc = crypto_aead_encrypt(aead_req); >> + >> + return rc; >> +} >> + >> +static void tls_init_aead_request(struct aead_request *aead_req, >> + struct crypto_aead *aead) >> +{ >> + aead_request_set_tfm(aead_req, aead); >> + aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE); >> +} >> + >> +static struct aead_request *tls_alloc_aead_request(struct crypto_aead *aead, >> + gfp_t flags) >> +{ >> + unsigned int req_size = sizeof(struct aead_request) + >> + crypto_aead_reqsize(aead); >> + struct aead_request *aead_req; >> + >> + aead_req = kzalloc(req_size, flags); >> + if (!aead_req) >> + return NULL; >> + >> + tls_init_aead_request(aead_req, aead); >> + return aead_req; >> +} >> + >> +static int tls_enc_records(struct aead_request *aead_req, >> + struct crypto_aead *aead, struct scatterlist *sg_in, >> + struct scatterlist *sg_out, char *aad, char *iv, >> + u64 rcd_sn, int len) >> +{ >> + struct scatter_walk in; >> + struct scatter_walk out; >> + int rc; >> + >> + scatterwalk_start(&in, sg_in); >> + scatterwalk_start(&out, sg_out); >> + >> + do { >> + rc = tls_enc_record(aead_req, aead, aad, iv, >> + cpu_to_be64(rcd_sn), &in, &out, &len); >> + rcd_sn++; >> + >> + } while (rc == 0 && len); >> + >> + scatterwalk_done(&in, 0, 0); >> + scatterwalk_done(&out, 1, 0); >> + >> + return rc; >> +} >> + >> +static inline void update_chksum(struct sk_buff *skb, int headln) >> +{ >> + /* Can't use icsk->icsk_af_ops->send_check here because the ip addresses >> + * might have been changed by NAT. >> + */ >> + >> + const struct ipv6hdr *ipv6h; >> + const struct iphdr *iph; >> + struct tcphdr *th = tcp_hdr(skb); >> + int datalen = skb->len - headln; >> + >> + /* We only changed the payload so if we are using partial we don't >> + * need to update anything. 
>> + */ >> + if (likely(skb->ip_summed == CHECKSUM_PARTIAL)) >> + return; >> + >> + skb->ip_summed = CHECKSUM_PARTIAL; >> + skb->csum_start = skb_transport_header(skb) - skb->head; >> + skb->csum_offset = offsetof(struct tcphdr, check); >> + >> + if (skb->sk->sk_family == AF_INET6) { >> + ipv6h = ipv6_hdr(skb); >> + th->check = ~csum_ipv6_magic(&ipv6h->saddr, &ipv6h->daddr, >> + datalen, IPPROTO_TCP, 0); >> + } else { >> + iph = ip_hdr(skb); >> + th->check = ~csum_tcpudp_magic(iph->saddr, iph->daddr, datalen, >> + IPPROTO_TCP, 0); >> + } >> +} >> + >> +static void complete_skb(struct sk_buff *nskb, struct sk_buff *skb, int headln) >> +{ >> + skb_copy_header(nskb, skb); >> + >> + skb_put(nskb, skb->len); >> + memcpy(nskb->data, skb->data, headln); >> + update_chksum(nskb, headln); >> + >> + nskb->destructor = skb->destructor; >> + nskb->sk = skb->sk; >> + skb->destructor = NULL; >> + skb->sk = NULL; >> + refcount_add(nskb->truesize - skb->truesize, >> + &nskb->sk->sk_wmem_alloc); >> +} >> + >> +/* This function may be called after the user socket is already >> + * closed so make sure we don't use anything freed during >> + * tls_sk_proto_close here >> + */ >> +static struct sk_buff *tls_sw_fallback(struct sock *sk, struct sk_buff *skb) >> +{ >> + int tcp_header_size = tcp_hdrlen(skb); >> + int tcp_payload_offset = skb_transport_offset(skb) + tcp_header_size; >> + int payload_len = skb->len - tcp_payload_offset; >> + struct tls_context *tls_ctx = tls_get_ctx(sk); >> + struct tls_offload_context *ctx = tls_offload_ctx(tls_ctx); >> + int remaining, buf_len, resync_sgs, rc, i = 0; >> + void *buf, *dummy_buf, *iv, *aad; >> + struct scatterlist *sg_in; >> + struct scatterlist sg_out[3]; >> + u32 tcp_seq = ntohl(tcp_hdr(skb)->seq); >> + struct aead_request *aead_req; >> + struct sk_buff *nskb = NULL; >> + struct tls_record_info *record; >> + unsigned long flags; >> + s32 sync_size; >> + u64 rcd_sn; >> + >> + /* worst case is: >> + * MAX_SKB_FRAGS in tls_record_info >> + * MAX_SKB_FRAGS + 1 in SKB head an frags. >> + */ >> + int sg_in_max_elements = 2 * MAX_SKB_FRAGS + 1; >> + >> + if (!payload_len) >> + return skb; >> + >> + sg_in = kmalloc_array(sg_in_max_elements, sizeof(*sg_in), GFP_ATOMIC); >> + if (!sg_in) >> + goto free_orig; >> + >> + sg_init_table(sg_in, sg_in_max_elements); >> + sg_init_table(sg_out, ARRAY_SIZE(sg_out)); >> + >> + spin_lock_irqsave(&ctx->lock, flags); >> + record = tls_get_record(ctx, tcp_seq, &rcd_sn); >> + if (!record) { >> + spin_unlock_irqrestore(&ctx->lock, flags); >> + WARN(1, "Record not found for seq %u\n", tcp_seq); >> + goto free_sg; >> + } >> + >> + sync_size = tcp_seq - tls_record_start_seq(record); >> + if (sync_size < 0) { >> + int is_start_marker = tls_record_is_start_marker(record); >> + >> + spin_unlock_irqrestore(&ctx->lock, flags); >> + if (!is_start_marker) >> + /* This should only occur if the relevant record was >> + * already acked. In that case it should be ok >> + * to drop the packet and avoid retransmission. >> + * >> + * There is a corner case where the packet contains >> + * both an acked and a non-acked record. >> + * We currently don't handle that case and rely >> + * on TCP to retranmit a packet that doesn't contain >> + * already acked payload. 
>> + */ >> + goto free_orig; >> + >> + if (payload_len > -sync_size) { >> + WARN(1, "Fallback of partially offloaded packets is not supported\n"); >> + goto free_sg; >> + } else { >> + return skb; >> + } >> + } >> + >> + remaining = sync_size; >> + while (remaining > 0) { >> + skb_frag_t *frag = &record->frags[i]; >> + >> + __skb_frag_ref(frag); >> + sg_set_page(sg_in + i, skb_frag_page(frag), >> + skb_frag_size(frag), frag->page_offset); >> + >> + remaining -= skb_frag_size(frag); >> + >> + if (remaining < 0) >> + sg_in[i].length += remaining; >> + >> + i++; >> + } >> + spin_unlock_irqrestore(&ctx->lock, flags); >> + resync_sgs = i; >> + >> + aead_req = tls_alloc_aead_request(ctx->aead_send, GFP_ATOMIC); >> + if (!aead_req) >> + goto put_sg; >> + >> + buf_len = TLS_CIPHER_AES_GCM_128_SALT_SIZE + >> + TLS_CIPHER_AES_GCM_128_IV_SIZE + >> + TLS_AAD_SPACE_SIZE + >> + sync_size + >> + tls_ctx->tag_size; >> + buf = kmalloc(buf_len, GFP_ATOMIC); >> + if (!buf) >> + goto free_req; >> + >> + nskb = alloc_skb(skb_headroom(skb) + skb->len, GFP_ATOMIC); >> + if (!nskb) >> + goto free_buf; >> + >> + skb_reserve(nskb, skb_headroom(skb)); >> + >> + iv = buf; >> + >> + memcpy(iv, tls_ctx->crypto_send_aes_gcm_128.salt, >> + TLS_CIPHER_AES_GCM_128_SALT_SIZE); >> + aad = buf + TLS_CIPHER_AES_GCM_128_SALT_SIZE + >> + TLS_CIPHER_AES_GCM_128_IV_SIZE; >> + dummy_buf = aad + TLS_AAD_SPACE_SIZE; >> + >> + sg_set_buf(&sg_out[0], dummy_buf, sync_size); >> + sg_set_buf(&sg_out[1], nskb->data + tcp_payload_offset, >> + payload_len); >> + /* Add room for authentication tag produced by crypto */ >> + dummy_buf += sync_size; >> + sg_set_buf(&sg_out[2], dummy_buf, tls_ctx->tag_size); >> + rc = skb_to_sgvec(skb, &sg_in[i], tcp_payload_offset, >> + payload_len); >> + if (rc < 0) >> + goto free_nskb; >> + >> + rc = tls_enc_records(aead_req, ctx->aead_send, sg_in, sg_out, aad, iv, >> + rcd_sn, sync_size + payload_len); >> + if (rc < 0) >> + goto free_nskb; >> + >> + complete_skb(nskb, skb, tcp_payload_offset); >> + >> + /* validate_xmit_skb_list assumes that if the skb wasn't segmented >> + * nskb->prev will point to the skb itself >> + */ >> + nskb->prev = nskb; >> +free_buf: >> + kfree(buf); >> +free_req: >> + kfree(aead_req); >> +put_sg: >> + for (i = 0; i < resync_sgs; i++) >> + put_page(sg_page(&sg_in[i])); >> +free_sg: >> + kfree(sg_in); >> +free_orig: >> + kfree_skb(skb); >> + return nskb; >> + >> +free_nskb: >> + kfree_skb(nskb); >> + nskb = NULL; >> + goto free_buf; >> +} >> + >> +static struct sk_buff *tls_validate_xmit_skb(struct sock *sk, >> + struct net_device *dev, >> + struct sk_buff *skb) >> +{ >> + if (dev == tls_get_ctx(sk)->netdev) >> + return skb; >> + >> + return tls_sw_fallback(sk, skb); >> +} >> + >> +int tls_sw_fallback_init(struct sock *sk, >> + struct tls_offload_context *offload_ctx, >> + struct tls_crypto_info *crypto_info) >> +{ >> + int rc; >> + const u8 *key; >> + >> + offload_ctx->aead_send = >> + crypto_alloc_aead("gcm(aes)", 0, CRYPTO_ALG_ASYNC); >> + if (IS_ERR(offload_ctx->aead_send)) { >> + rc = PTR_ERR(offload_ctx->aead_send); >> + pr_err_ratelimited("crypto_alloc_aead failed rc=%d\n", rc); >> + offload_ctx->aead_send = NULL; >> + goto err_out; >> + } >> + >> + key = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->key; >> + >> + rc = crypto_aead_setkey(offload_ctx->aead_send, key, >> + TLS_CIPHER_AES_GCM_128_KEY_SIZE); >> + if (rc) >> + goto free_aead; >> + >> + rc = crypto_aead_setauthsize(offload_ctx->aead_send, >> + TLS_CIPHER_AES_GCM_128_TAG_SIZE); >> + if (rc) >> + goto 
free_aead; >> + >> + sk->sk_validate_xmit_skb = tls_validate_xmit_skb; >> + return 0; >> +free_aead: >> + crypto_free_aead(offload_ctx->aead_send); >> +err_out: >> + return rc; >> +} >> diff --git a/net/tls/tls_main.c b/net/tls/tls_main.c >> index d824d548447e..e0dface33017 100644 >> --- a/net/tls/tls_main.c >> +++ b/net/tls/tls_main.c >> @@ -54,6 +54,9 @@ enum { >> enum { >> TLS_BASE_TX, >> TLS_SW_TX, >> +#ifdef CONFIG_TLS_DEVICE >> + TLS_HW_TX, >> +#endif >> TLS_NUM_CONFIG, >> }; >> >> @@ -416,11 +419,19 @@ static int do_tls_setsockopt_tx(struct sock *sk, char __user *optval, >> goto err_crypto_info; >> } >> >> - /* currently SW is default, we will have ethtool in future */ >> - rc = tls_set_sw_offload(sk, ctx); >> - tx_conf = TLS_SW_TX; >> - if (rc) >> - goto err_crypto_info; >> +#ifdef CONFIG_TLS_DEVICE >> + rc = tls_set_device_offload(sk, ctx); >> + tx_conf = TLS_HW_TX; >> + if (rc) { >> +#else >> + { >> +#endif >> + /* if HW offload fails fallback to SW */ >> + rc = tls_set_sw_offload(sk, ctx); >> + tx_conf = TLS_SW_TX; >> + if (rc) >> + goto err_crypto_info; >> + } >> >> ctx->tx_conf = tx_conf; >> update_sk_prot(sk, ctx); >> @@ -473,6 +484,12 @@ static void build_protos(struct proto *prot, struct proto *base) >> prot[TLS_SW_TX] = prot[TLS_BASE_TX]; >> prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; >> prot[TLS_SW_TX].sendpage = tls_sw_sendpage; >> + >> +#ifdef CONFIG_TLS_DEVICE >> + prot[TLS_HW_TX] = prot[TLS_SW_TX]; >> + prot[TLS_HW_TX].sendmsg = tls_device_sendmsg; >> + prot[TLS_HW_TX].sendpage = tls_device_sendpage; >> +#endif >> } >> >> static int tls_init(struct sock *sk) >> @@ -531,6 +548,9 @@ static int __init tls_register(void) >> { >> build_protos(tls_prots[TLSV4], &tcp_prot); >> >> +#ifdef CONFIG_TLS_DEVICE >> + tls_device_init(); >> +#endif >> tcp_register_ulp(&tcp_tls_ulp_ops); >> >> return 0; >> @@ -539,6 +559,9 @@ static int __init tls_register(void) >> static void __exit tls_unregister(void) >> { >> tcp_unregister_ulp(&tcp_tls_ulp_ops); >> +#ifdef CONFIG_TLS_DEVICE >> + tls_device_cleanup(); >> +#endif >> } >> >> module_init(tls_register); > > Thanks, > Kirill > Best, Boris.